This commit is contained in:
HakurrrPunk 2023-07-15 14:42:05 +10:00 committed by GitHub
commit 5fbb653f56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -1,12 +1,11 @@
import argparse
import logging
import os
import pkgutil
import sys
from urllib.parse import urlsplit
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from .apache_beam.options.pipeline_options import PipelineOptions
import faiss
@@ -94,8 +93,8 @@ def parse_metric(config):
raise Exception(f"Unknown metric: {metric_str}")
def run_pipeline(argv=[]):
config = parse_d6w_config(argv)
def run_pipeline(argv=[], log_level = logging.INFO):
config = parse_d6w_config(argv=None)
argv_with_extras = argv
if config["gpu"]:
argv_with_extras.extend(["--experiments", "use_runner_v2"])
@@ -108,7 +107,7 @@ def run_pipeline(argv=[]):
"gcr.io/twttr-recos-ml-prod/dataflow-gpu/beam2_39_0_py3_7",
]
)
logging.getLogger().setLevel(log_level)
options = PipelineOptions(argv_with_extras)
output_bucket_name = urlsplit(config["output_location"]).netloc
@@ -228,5 +227,10 @@ class MergeAndBuildIndex(beam.CombineFn):
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
run_pipeline(sys.argv)
parser = argparse.ArgumentParser()
parser.add_argument("--log_level", dest="log_level", default="INFO", help="Logging level")
args, pipeline_args = parser.parse_known_args()
logging.getLogger().setLevel(args.log_level)
run_pipeline(pipeline_args, log_level=args.log_level)