Skip to content

Commit b78714f

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add progress bar for generating inference.
PiperOrigin-RevId: 663828395
1 parent 3974aec commit b78714f

File tree

1 file changed

+22
-15
lines changed

1 file changed

+22
-15
lines changed

vertexai/preview/evaluation/_evaluation.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -324,28 +324,34 @@ def _generate_response_from_gemini_model(
324324
constants.Dataset.COMPLETED_PROMPT_COLUMN
325325
in evaluation_run_config.dataset.columns
326326
):
327-
with futures.ThreadPoolExecutor(max_workers=constants.MAX_WORKERS) as executor:
328-
for _, row in df.iterrows():
329-
tasks.append(
330-
executor.submit(
327+
with tqdm(total=len(df)) as pbar:
328+
with futures.ThreadPoolExecutor(
329+
max_workers=constants.MAX_WORKERS
330+
) as executor:
331+
for _, row in df.iterrows():
332+
task = executor.submit(
331333
_generate_response_from_gemini,
332334
prompt=row[constants.Dataset.COMPLETED_PROMPT_COLUMN],
333335
model=model,
334336
)
335-
)
337+
task.add_done_callback(lambda _: pbar.update(1))
338+
tasks.append(task)
336339
else:
337340
content_column_name = evaluation_run_config.column_map[
338341
constants.Dataset.CONTENT_COLUMN
339342
]
340-
with futures.ThreadPoolExecutor(max_workers=constants.MAX_WORKERS) as executor:
341-
for _, row in df.iterrows():
342-
tasks.append(
343-
executor.submit(
343+
with tqdm(total=len(df)) as pbar:
344+
with futures.ThreadPoolExecutor(
345+
max_workers=constants.MAX_WORKERS
346+
) as executor:
347+
for _, row in df.iterrows():
348+
task = executor.submit(
344349
_generate_response_from_gemini,
345350
prompt=row[content_column_name],
346351
model=model,
347352
)
348-
)
353+
task.add_done_callback(lambda _: pbar.update(1))
354+
tasks.append(task)
349355
responses = [future.result() for future in tasks]
350356
if is_baseline_model:
351357
evaluation_run_config.dataset = df.assign(baseline_model_response=responses)
@@ -384,13 +390,14 @@ def _generate_response_from_custom_model_fn(
384390
constants.Dataset.COMPLETED_PROMPT_COLUMN
385391
in evaluation_run_config.dataset.columns
386392
):
387-
with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
388-
for _, row in df.iterrows():
389-
tasks.append(
390-
executor.submit(
393+
with tqdm(total=len(df)) as pbar:
394+
with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
395+
for _, row in df.iterrows():
396+
task = executor.submit(
391397
model_fn, row[constants.Dataset.COMPLETED_PROMPT_COLUMN]
392398
)
393-
)
399+
task.add_done_callback(lambda _: pbar.update(1))
400+
tasks.append(task)
394401
else:
395402
content_column_name = evaluation_run_config.column_map[
396403
constants.Dataset.CONTENT_COLUMN

0 commit comments

Comments
 (0)