feat: add Pandas DataFrame support to TabularDataset #1185
Changes from all commits:
```diff
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-

-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
```
```diff
@@ -19,12 +19,18 @@
 from google.auth import credentials as auth_credentials

+from google.cloud import bigquery
 from google.cloud.aiplatform import base
 from google.cloud.aiplatform import datasets
 from google.cloud.aiplatform.datasets import _datasources
+from google.cloud.aiplatform import initializer
 from google.cloud.aiplatform import schema
 from google.cloud.aiplatform import utils

+_AUTOML_TRAINING_MIN_ROWS = 1000
+
+_LOGGER = base.Logger(__name__)
+

 class TabularDataset(datasets._ColumnNamesDataset):
     """Managed tabular dataset resource for Vertex AI."""
```
```diff
@@ -146,6 +152,112 @@ def create(
             create_request_timeout=create_request_timeout,
         )

+    @classmethod
+    def create_from_dataframe(
+        cls,
+        df_source: "pd.DataFrame",  # noqa: F821 - skip check for undefined name 'pd'
+        staging_path: str,
+        bq_schema: Optional[Union[str, bigquery.SchemaField]] = None,
+        display_name: Optional[str] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> "TabularDataset":
+        """Creates a new tabular dataset from a Pandas DataFrame.
+
+        Args:
+            df_source (pd.DataFrame):
+                Required. Pandas DataFrame containing the source data for
+                ingestion as a TabularDataset. This method will use the data
+                types from the provided DataFrame when creating the dataset.
+            staging_path (str):
```
Member: It seems the location requirements should also be documented, or a reference to this documentation should be provided: https://cloud.google.com/vertex-ai/docs/general/locations#bq-locations. If possible they should be validated, but not as a hard requirement. Is it possible for the dataset create to fail because of the regional requirements?

Contributor (Author): Good point. I just tested it, and it can fail if the dataset location doesn't match the project location or the service doesn't have the right access to the dataset. I'll update the docstring to link to that page. In terms of validating, the BQ client throws this error: `google.api_core.exceptions.FailedPrecondition: 400 BigQuery Dataset location`. Do you think we should validate as well, or let the BQ client handle validation? If we do validation, we'd need to use the BQ client to check the location of the provided BQ dataset string.

Member: Agree with relying on the BQ client.
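To make the location requirement concrete, here is a rough sketch (not part of the PR) of creating the staging BigQuery dataset in an accepted location before calling the new method; the project and dataset names are placeholders.

```python
from google.cloud import bigquery

# Placeholder project/dataset names. The staging dataset's location must satisfy
# https://cloud.google.com/vertex-ai/docs/general/locations#bq-locations,
# otherwise dataset creation fails with a FailedPrecondition error from BigQuery.
bq_client = bigquery.Client(project="my-project")

staging_dataset = bigquery.Dataset("my-project.vertex_staging")
staging_dataset.location = "US"  # a multi-region accepted by Vertex AI

bq_client.create_dataset(staging_dataset, exists_ok=True)
```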
```diff
+                Required. The BigQuery table to stage the data
+                for Vertex. Because Vertex maintains a reference to this source
+                to create the Vertex Dataset, this BigQuery table should
+                not be deleted. Example: `bq://my-project.my-dataset.my-table`.
+                If the provided BigQuery table doesn't exist, this method will
+                create the table. If the provided BigQuery table already exists,
+                and the schemas of the BigQuery table and your DataFrame match,
+                this method will append the data in your local DataFrame to the table.
+                The location of the provided BigQuery table should conform to the location requirements
+                specified here: https://cloud.google.com/vertex-ai/docs/general/locations#bq-locations.
+            bq_schema (Optional[Union[str, bigquery.SchemaField]]):
+                Optional. If not set, BigQuery will autodetect the schema using your DataFrame's column types.
+                If set, BigQuery will use the schema you provide when creating the staging table. For more details,
+                see: https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.job.LoadJobConfig#google_cloud_bigquery_job_LoadJobConfig_schema
+            display_name (str):
+                Optional. The user-defined name of the Dataset.
+                The name can be up to 128 characters long and can consist
+                of any UTF-8 characters.
+            project (str):
+                Optional. Project to upload this dataset to. Overrides project set in
+                aiplatform.init.
+            location (str):
+                Optional. Location to upload this dataset to. Overrides location set in
+                aiplatform.init.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to upload this dataset. Overrides
+                credentials set in aiplatform.init.
+        Returns:
+            tabular_dataset (TabularDataset):
+                Instantiated representation of the managed tabular dataset resource.
+        """
+
+        if staging_path.startswith("bq://"):
+            bq_staging_path = staging_path[len("bq://") :]
+        else:
+            raise ValueError(
+                "Only BigQuery staging paths are supported. Provide a staging path in the format `bq://your-project.your-dataset.your-table`."
+            )
+
+        try:
+            import pyarrow  # noqa: F401 - skip check for 'pyarrow' which is required when using 'google.cloud.bigquery'
+        except ImportError:
+            raise ImportError(
+                "Pyarrow is not installed, and is required to use the BigQuery client."
+                'Please install the SDK using "pip install google-cloud-aiplatform[datasets]"'
+            )
+
+        if len(df_source) < _AUTOML_TRAINING_MIN_ROWS:
+            _LOGGER.info(
+                "Your DataFrame has %s rows and AutoML requires %s rows to train on tabular data. You can still train a custom model once your dataset has been uploaded to Vertex, but you will not be able to use AutoML for training."
+                % (len(df_source), _AUTOML_TRAINING_MIN_ROWS),
+            )
+
+        bigquery_client = bigquery.Client(
+            project=project or initializer.global_config.project,
+            credentials=credentials or initializer.global_config.credentials,
+        )
+
+        try:
+            parquet_options = bigquery.format_options.ParquetOptions()
+            parquet_options.enable_list_inference = True
+
+            job_config = bigquery.LoadJobConfig(
```
Member: Will this config infer all the types? I see the `enable_list_inference`, but I couldn't find a reference in the BQ docs for non-list type inference.

Contributor (Author): Yes, this will infer the data types from the DF. From the BQ docs: … I added a …

Member: I think we can rely on BQ client validation.
```diff
+                source_format=bigquery.SourceFormat.PARQUET,
+                parquet_options=parquet_options,
+            )
```
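Following up on the type-inference thread above, a small standalone sketch (not part of the PR) of how BigQuery derives the staging table's schema from a DataFrame loaded via Parquet; the project and table names are placeholders.

```python
import pandas as pd
from google.cloud import bigquery

# Placeholder names; requires pyarrow, as noted in the PR's ImportError message.
client = bigquery.Client(project="my-project")

parquet_options = bigquery.format_options.ParquetOptions()
parquet_options.enable_list_inference = True  # map list columns to REPEATED fields

job_config = bigquery.LoadJobConfig(
    source_format=bigquery.SourceFormat.PARQUET,
    parquet_options=parquet_options,
)

df = pd.DataFrame({"int64_col": [1, 2], "int64_array_col": [[1, 2], [3, 4]]})
client.load_table_from_dataframe(
    df, "my-project.my_dataset.my_table", job_config=job_config
).result()

# The table schema is derived from the Parquet types written from the DataFrame
# (INTEGER and REPEATED INTEGER here), with no user-supplied schema.
print(client.get_table("my-project.my_dataset.my_table").schema)
```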
```diff
+
+            if bq_schema:
+                job_config.schema = bq_schema
+
+            job = bigquery_client.load_table_from_dataframe(
+                dataframe=df_source, destination=bq_staging_path, job_config=job_config
+            )
+
+            job.result()
+
+        finally:
+            dataset_from_dataframe = cls.create(
+                display_name=display_name,
+                bq_source=staging_path,
+                project=project,
+                location=location,
+                credentials=credentials,
+            )
+
+        return dataset_from_dataframe
+
     def import_data(self):
         raise NotImplementedError(
             f"{self.__class__.__name__} class does not support 'import_data'"
```
The second file in the diff updates the dataset system tests:
```diff
@@ -20,10 +20,14 @@
 import pytest
 import importlib

+import pandas as pd
+
 from google import auth as google_auth
 from google.api_core import exceptions
 from google.api_core import client_options

+from google.cloud import bigquery
+
 from google.cloud import aiplatform
 from google.cloud import storage
 from google.cloud.aiplatform import utils

@@ -33,6 +37,8 @@
 from test_utils.vpcsc_config import vpcsc_config

+from tests.system.aiplatform import e2e_base
+
 # TODO(vinnys): Replace with env var `BUILD_SPECIFIC_GCP_PROJECT` once supported
 _, _TEST_PROJECT = google_auth.default()
 TEST_BUCKET = os.environ.get(
```
```diff
@@ -55,40 +61,91 @@
 _TEST_TEXT_ENTITY_IMPORT_SCHEMA = "gs://google-cloud-aiplatform/schema/dataset/ioformat/text_extraction_io_format_1.0.0.yaml"
 _TEST_IMAGE_OBJ_DET_IMPORT_SCHEMA = "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_bounding_box_io_format_1.0.0.yaml"

+# create_from_dataframe
+_TEST_BOOL_COL = "bool_col"
+_TEST_BOOL_ARR_COL = "bool_array_col"
+_TEST_DOUBLE_COL = "double_col"
+_TEST_DOUBLE_ARR_COL = "double_array_col"
+_TEST_INT_COL = "int64_col"
+_TEST_INT_ARR_COL = "int64_array_col"
+_TEST_STR_COL = "string_col"
+_TEST_STR_ARR_COL = "string_array_col"
+_TEST_BYTES_COL = "bytes_col"
+_TEST_DF_COLUMN_NAMES = [
+    _TEST_BOOL_COL,
+    _TEST_BOOL_ARR_COL,
+    _TEST_DOUBLE_COL,
+    _TEST_DOUBLE_ARR_COL,
+    _TEST_INT_COL,
+    _TEST_INT_ARR_COL,
+    _TEST_STR_COL,
+    _TEST_STR_ARR_COL,
+    _TEST_BYTES_COL,
+]
+_TEST_DATAFRAME = pd.DataFrame(
+    data=[
+        [
+            False,
+            [True, False],
+            1.2,
+            [1.2, 3.4],
+            1,
+            [1, 2],
+            "test",
+            ["test1", "test2"],
+            b"1",
+        ],
+        [
+            True,
+            [True, True],
+            2.2,
+            [2.2, 4.4],
+            2,
+            [2, 3],
+            "test1",
+            ["test2", "test3"],
+            b"0",
+        ],
+    ],
+    columns=_TEST_DF_COLUMN_NAMES,
+)
+_TEST_DATAFRAME_BQ_SCHEMA = [
+    bigquery.SchemaField(name="bool_col", field_type="BOOL"),
+    bigquery.SchemaField(name="bool_array_col", field_type="BOOL", mode="REPEATED"),
+    bigquery.SchemaField(name="double_col", field_type="FLOAT"),
+    bigquery.SchemaField(name="double_array_col", field_type="FLOAT", mode="REPEATED"),
+    bigquery.SchemaField(name="int64_col", field_type="INTEGER"),
+    bigquery.SchemaField(name="int64_array_col", field_type="INTEGER", mode="REPEATED"),
+    bigquery.SchemaField(name="string_col", field_type="STRING"),
+    bigquery.SchemaField(name="string_array_col", field_type="STRING", mode="REPEATED"),
+    bigquery.SchemaField(name="bytes_col", field_type="STRING"),
+]
+
+
+@pytest.mark.usefixtures(
+    "prepare_staging_bucket",
+    "delete_staging_bucket",
+    "prepare_bigquery_dataset",
+    "delete_bigquery_dataset",
+    "tear_down_resources",
+)
+class TestDataset(e2e_base.TestEndToEnd):
+
+    _temp_prefix = "temp-vertex-sdk-dataset-test"
+
-class TestDataset:
     def setup_method(self):
         importlib.reload(initializer)
         importlib.reload(aiplatform)

-    @pytest.fixture()
-    def shared_state(self):
-        shared_state = {}
-        yield shared_state
-
-    @pytest.fixture()
-    def create_staging_bucket(self, shared_state):
-        new_staging_bucket = f"temp-sdk-integration-{uuid.uuid4()}"
-
-        storage_client = storage.Client()
-        storage_client.create_bucket(new_staging_bucket)
-        shared_state["storage_client"] = storage_client
-        shared_state["staging_bucket"] = new_staging_bucket
-        yield
-
-    @pytest.fixture()
-    def delete_staging_bucket(self, shared_state):
-        yield
-        storage_client = shared_state["storage_client"]
-
-        # Delete temp staging bucket
-        bucket_to_delete = storage_client.get_bucket(shared_state["staging_bucket"])
-        bucket_to_delete.delete(force=True)
-
-        # Close Storage Client
-        storage_client._http._auth_request.session.close()
-        storage_client._http.close()
-
     @pytest.fixture()
     def dataset_gapic_client(self):
         gapic_client = dataset_service.DatasetServiceClient(
```
```diff
@@ -253,6 +310,74 @@ def test_create_tabular_dataset(self, dataset_gapic_client, shared_state):
             == aiplatform.schema.dataset.metadata.tabular
         )

+    @pytest.mark.usefixtures("delete_new_dataset")
+    def test_create_tabular_dataset_from_dataframe(
+        self, dataset_gapic_client, shared_state
+    ):
+        """Use the Dataset.create_from_dataframe() method to create a new tabular dataset.
+        Then confirm the dataset was successfully created and references the BQ source."""
+
+        assert shared_state["bigquery_dataset"]
+
+        shared_state["resources"] = []
+
+        bigquery_dataset_id = shared_state["bigquery_dataset_id"]
+        bq_staging_table = f"bq://{bigquery_dataset_id}.test_table{uuid.uuid4()}"
+
+        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+        tabular_dataset = aiplatform.TabularDataset.create_from_dataframe(
```
Member: `tabular_dataset` should be appended to `shared_state["resources"]`; see example: https://github.com/googleapis/python-aiplatform/blob/main/tests/system/aiplatform/test_e2e_tabular.py#L89

Contributor (Author): Updated both new tests to use `shared_state["resources"]`.
```diff
+            df_source=_TEST_DATAFRAME,
+            staging_path=bq_staging_table,
+            display_name=f"temp_sdk_integration_create_and_import_dataset_from_dataframe{uuid.uuid4()}",
+        )
+        shared_state["resources"].extend([tabular_dataset])
+        shared_state["dataset_name"] = tabular_dataset.resource_name
+
+        gapic_metadata = tabular_dataset.to_dict()["metadata"]
+        bq_source = gapic_metadata["inputConfig"]["bigquerySource"]["uri"]
+
+        assert bq_staging_table == bq_source
+        assert (
+            tabular_dataset.metadata_schema_uri
+            == aiplatform.schema.dataset.metadata.tabular
+        )
```
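As the thread above points out, appending each created dataset to `shared_state["resources"]` lets the shared teardown clean it up. Roughly, the pattern looks like the sketch below; this is a simplification for illustration, not the actual `e2e_base.TestEndToEnd` code.

```python
import pytest


@pytest.fixture(scope="class")
def tear_down_resources(shared_state):
    """Simplified sketch: delete every Vertex resource a test registered."""
    yield
    for resource in shared_state.get("resources", []):
        try:
            resource.delete()
        except Exception:
            # Keep tearing down the remaining resources even if one delete fails.
            pass
```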
```diff
+    @pytest.mark.usefixtures("delete_new_dataset")
+    def test_create_tabular_dataset_from_dataframe_with_provided_schema(
+        self, dataset_gapic_client, shared_state
+    ):
+        """Use the Dataset.create_from_dataframe() method to create a new tabular dataset,
+        passing in the optional `bq_schema` argument. Then confirm the dataset was successfully
+        created and references the BQ source."""
+
+        assert shared_state["bigquery_dataset"]
+
+        shared_state["resources"] = []
+
+        bigquery_dataset_id = shared_state["bigquery_dataset_id"]
+        bq_staging_table = f"bq://{bigquery_dataset_id}.test_table{uuid.uuid4()}"
+
+        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+        tabular_dataset = aiplatform.TabularDataset.create_from_dataframe(
+            df_source=_TEST_DATAFRAME,
+            staging_path=bq_staging_table,
+            display_name=f"temp_sdk_integration_create_and_import_dataset_from_dataframe{uuid.uuid4()}",
+            bq_schema=_TEST_DATAFRAME_BQ_SCHEMA,
+        )
+        shared_state["resources"].extend([tabular_dataset])
+        shared_state["dataset_name"] = tabular_dataset.resource_name
+
+        gapic_metadata = tabular_dataset.to_dict()["metadata"]
+        bq_source = gapic_metadata["inputConfig"]["bigquerySource"]["uri"]
+
+        assert bq_staging_table == bq_source
+        assert (
+            tabular_dataset.metadata_schema_uri
+            == aiplatform.schema.dataset.metadata.tabular
+        )
+
     # TODO(vinnys): Remove pytest skip once persistent resources are accessible
     @pytest.mark.skip(reason="System tests cannot access persistent test resources")
     @pytest.mark.usefixtures("create_staging_bucket", "delete_staging_bucket")
```