From 19530cbc951b76c0e95628adcf3994129837b43b Mon Sep 17 00:00:00 2001
From: dberenbaum
Date: Fri, 2 Feb 2024 19:48:10 +0000
Subject: [PATCH] debug byo training script

---
 dvc.lock         | 28 ++++++++++--------
 dvc.yaml         | 10 ++++++-
 requirements.txt |  4 ++-
 sm_training.py   | 77 +++++++++++++++++++++++++++++++++---------------
 train.py         | 62 ++++++++++++++++++++++++++++++++++++++
 5 files changed, 144 insertions(+), 37 deletions(-)
 create mode 100644 train.py

diff --git a/dvc.lock b/dvc.lock
index 2748199..90ae923 100644
--- a/dvc.lock
+++ b/dvc.lock
@@ -2,9 +2,9 @@ schema: '2.0'
 stages:
   prepare:
     cmd:
-    - rm -f bank-additional.zip
-    - wget
+    - curl
       https://sagemaker-sample-data-us-west-2.s3-us-west-2.amazonaws.com/autopilot/direct_marketing/bank-additional.zip
+      -o bank-additional.zip
     - python sm_prepare.py --bucket dave-sandbox --prefix sagemaker/DEMO-xgboost-dm
     deps:
     - path:
@@ -41,17 +41,17 @@ stages:
     outs:
     - path: s3://dave-sandbox/sagemaker/DEMO-xgboost-dm/test
       hash: md5
-      md5: 2c231c2500d0475fbbd28afdf10964aa.dir
+      md5: 4395b17f71bb1bd5f6e9d0e22e960e81.dir
       size: 506467
       nfiles: 2
     - path: s3://dave-sandbox/sagemaker/DEMO-xgboost-dm/train
       hash: md5
-      md5: 682392fa257794c9fd432693002ce2e4.dir
+      md5: e2e0dcf49e644f580342279f550ad574.dir
       size: 3545009
       nfiles: 1
     - path: s3://dave-sandbox/sagemaker/DEMO-xgboost-dm/validation
       hash: md5
-      md5: 7ddcae9eeea71d4423832ed7f251a286.dir
+      md5: 5c3e7e009a90314d9a9531a5b68c45d5.dir
       size: 1012968
       nfiles: 1
   training:
@@ -61,21 +61,25 @@ stages:
     deps:
     - path: s3://dave-sandbox/sagemaker/DEMO-xgboost-dm/train
       hash: md5
-      md5: 682392fa257794c9fd432693002ce2e4.dir
+      md5: e2e0dcf49e644f580342279f550ad574.dir
       size: 3545009
       nfiles: 1
     - path: s3://dave-sandbox/sagemaker/DEMO-xgboost-dm/validation
       hash: md5
-      md5: 7ddcae9eeea71d4423832ed7f251a286.dir
+      md5: 5c3e7e009a90314d9a9531a5b68c45d5.dir
       size: 1012968
       nfiles: 1
     - path: sm_training.py
       hash: md5
-      md5: 73ea3430208898e6ce2879a97c511464
-      size: 2119
+      md5: 3c3804455d5bf355677fbc8630659084
+      size: 3427
+    - path: train.py
+      hash: md5
+      md5: d06e210befac037f4216ea326d9d3e81
+      size: 2211
     outs:
     - path: s3://dave-sandbox/sagemaker/DEMO-xgboost-dm/output
       hash: md5
-      md5: a1ce699466f31bf065ed1382cfd80c5c.dir
-      size: 416993
-      nfiles: 7
+      md5: ad60d9f3bafa5f35403ea433c097ea5d.dir
+      size: 393592
+      nfiles: 6
diff --git a/dvc.yaml b/dvc.yaml
index 76d5cc3..540a6dd 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -1,7 +1,9 @@
 stages:
   prepare:
     cmd:
-    - wget https://sagemaker-sample-data-us-west-2.s3-us-west-2.amazonaws.com/autopilot/direct_marketing/bank-additional.zip -O bank-additional.zip
+    - curl
+      https://sagemaker-sample-data-us-west-2.s3-us-west-2.amazonaws.com/autopilot/direct_marketing/bank-additional.zip
+      -o bank-additional.zip
     - python sm_prepare.py --bucket ${bucket} --prefix ${prefix}
     deps:
     - sm_prepare.py
@@ -26,8 +28,14 @@ stages:
     cmd: python sm_training.py --bucket ${bucket} --prefix ${prefix} ${train}
     deps:
     - sm_training.py
+    - train.py
     - s3://${bucket}/${prefix}/train
     - s3://${bucket}/${prefix}/validation
     outs:
     - s3://${bucket}/${prefix}/output:
         cache: false
+plots:
+- dvclive/plots/metrics:
+    x: step
+params:
+- dvclive/params.yaml
diff --git a/requirements.txt b/requirements.txt
index f778f1d..4e9f109 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,5 @@
 dvc[s3]
 pandas
-sagemaker
\ No newline at end of file
+sagemaker
+xgboost
+git+https://github.com/iterative/dvclive.git@no-git
\ No newline at end of file
diff --git a/sm_training.py b/sm_training.py
index 8ab54c7..f49c841 100644
--- a/sm_training.py
+++ b/sm_training.py
@@ -1,7 +1,10 @@
 import argparse
+import os
+
 import boto3
 import sagemaker
 from sagemaker import get_execution_role
+from sagemaker.xgboost.estimator import XGBoost
 
 
 def main():
@@ -14,7 +17,7 @@ def main():
     parser.add_argument("--min_child_weight", type=int)
     parser.add_argument("--subsample", type=float)
     parser.add_argument("--silent", type=int)
-    parser.add_argument("--objective")
+    parser.add_argument("--objective", type=str)
     parser.add_argument("--num_round", type=int)
 
     args = parser.parse_args()
@@ -23,32 +26,60 @@ def main():
     bucket = args.bucket
     prefix = args.prefix
 
     train_path = f"s3://{bucket}/{prefix}/train"
     validation_path = f"s3://{bucket}/{prefix}/validation"
-
-
-    container = sagemaker.image_uris.retrieve(region=boto3.Session().region_name, framework='xgboost', version='latest')
-
+
+    #container = sagemaker.image_uris.retrieve(region=boto3.Session().region_name, framework='xgboost', version='latest')
     s3_input_train = sagemaker.inputs.TrainingInput(s3_data=train_path.format(bucket, prefix), content_type='csv')
     s3_input_validation = sagemaker.inputs.TrainingInput(s3_data=validation_path.format(bucket, prefix), content_type='csv')
+
+    #sess = sagemaker.Session()
+
+    env = {name: value for name, value in os.environ.items() if name.startswith("DVC")}
+    print(env)
+
+    #xgb = sagemaker.estimator.Estimator(container,
+    #                                    get_execution_role(),
+    #                                    instance_count=1,
+    #                                    instance_type='ml.m4.xlarge',
+    #                                    output_path='s3://{}/{}/output'.format(bucket, prefix),
+    #                                    sagemaker_session=sess,
+    #                                    entry_point="train.py", # include training script
+    #                                    source_dir=".", # include repo path that points to requirements.txt
+    #                                    env=env, # pass dvc environment variables
+    #                                    )
+
+    #xgb.set_hyperparameters(max_depth=args.max_depth,
+    #                        eta=args.eta,
+    #                        gamma=args.gamma,
+    #                        min_child_weight=args.min_child_weight,
+    #                        subsample=args.subsample,
+    #                        silent=args.silent,
+    #                        objective=args.objective,
+    #                        num_round=args.num_round)
 
-
-    sess = sagemaker.Session()
-
-    xgb = sagemaker.estimator.Estimator(container,
-                                        get_execution_role(),
-                                        instance_count=1,
-                                        instance_type='ml.m4.xlarge',
-                                        output_path='s3://{}/{}/output'.format(bucket, prefix),
-                                        sagemaker_session=sess)
-    xgb.set_hyperparameters(max_depth=args.max_depth,
-                            eta=args.eta,
-                            gamma=args.gamma,
-                            min_child_weight=args.min_child_weight,
-                            subsample=args.subsample,
-                            silent=args.silent,
-                            objective=args.objective,
-                            num_round=args.num_round)
-
+    hyperparameters = {
+        "max_depth": args.max_depth,
+        "eta": args.eta,
+        "gamma": args.gamma,
+        "min_child_weight": args.min_child_weight,
+        "subsample": args.subsample,
+        "silent": args.silent,
+        #"objective": args.objective,
+        "num_round": args.num_round,
+    }
+
+    xgb = XGBoost(
+        role=get_execution_role(),
+        instance_count=1,
+        instance_type="ml.m4.xlarge",
+        framework_version="latest",
+        output_path='s3://{}/{}/output'.format(bucket, prefix),
+        hyperparameters=hyperparameters,
+        entry_point="train.py", # include training script
+        source_dir=".", # include repo path that points to requirements.txt
+        env=env, # pass dvc environment variables
+    )
+
     xgb.fit({'train': s3_input_train, 'validation': s3_input_validation})
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..791e5df
--- /dev/null
+++ b/train.py
@@ -0,0 +1,62 @@
+import argparse
+import logging
+import pickle
+import os
+
+import xgboost as xgb
+from dvclive import Live
+from dvclive.xgb import DVCLiveCallback
+
+
+if __name__ == '__main__':
+    logging.info("starting train script.")
+    parser = argparse.ArgumentParser()
+
+    # Hyperparameters are described here
+    parser.add_argument("--max_depth", type=int)
+    parser.add_argument("--eta", type=float)
+    parser.add_argument("--gamma", type=int)
+    parser.add_argument("--min_child_weight", type=int)
+    parser.add_argument("--subsample", type=float)
+    parser.add_argument("--silent", type=int)
+    parser.add_argument("--objective", type=str, default="binary:logistic")
+    parser.add_argument("--num_round", type=int)
+
+    # SageMaker specific arguments. Defaults are set in the environment variables.
+    parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR'))
+    parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN'))
+    parser.add_argument('--validation', type=str, default=os.environ.get('SM_CHANNEL_VALIDATION'))
+
+    args = parser.parse_args()
+
+    train_hp = {
+        'max_depth': args.max_depth,
+        'eta': args.eta,
+        'gamma': args.gamma,
+        'min_child_weight': args.min_child_weight,
+        'subsample': args.subsample,
+        'silent': args.silent,
+        #'objective': args.objective,
+    }
+
+    dtrain = xgb.DMatrix(args.train + "?format=csv&label_column=0")
+    dval = xgb.DMatrix(args.validation + "?format=csv&label_column=0")
+    watchlist = [(dtrain, 'train'), (dval, 'validation')] if dval is not None else [(dtrain, 'train')]
+
+    with Live(os.path.join(args.model_dir + "/dvclive")) as live:
+        live.log_param("cwd", os.getcwd())
+
+    callbacks = [DVCLiveCallback()]
+
+    bst = xgb.train(
+        params=train_hp,
+        dtrain=dtrain,
+        evals=watchlist,
+        num_boost_round=args.num_round,
+        callbacks=callbacks
+    )
+
+    # Save the model to the location specified by ``model_dir``
+    model_location = args.model_dir + '/xgboost-model'
+    pickle.dump(bst, open(model_location, 'wb'))
+    logging.info("Stored trained model at {}".format(model_location))
\ No newline at end of file