[WIP] Fix redundant nvml init by cached handles #56
@@ -11,12 +11,16 @@
 from .core import GPUStatCollection


-def print_gpustat(json=False, debug=False, **kwargs):
+def print_gpustat(json=False, debug=False, gpu_stat=None, **kwargs):
     '''
     Display the GPU query results into standard output.
     '''
     try:
-        gpu_stats = GPUStatCollection.new_query()
+        if gpu_stat is None:
+            gpu_stat = GPUStatCollection()
+        else:
+            gpu_stat.update()
     except Exception as e:
         sys.stderr.write('Error on querying NVIDIA devices.'
                          ' Use --debug flag for details\n')

Review comments on this error-handling block:

- The error-catching workflow is complicated for me; can we just throw the raw error messages instead of the current one?
- For example, if an error happens during the loop, we can simply throw it out.
- The error catching requires all the NVML API calls to be wrapped with a long …
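To make the suggestion concrete: a minimal sketch, reusing the print_gpustat and GPUStatCollection names from the diff above, of what letting the raw error propagate could look like. This illustrates the review comment, not the code that was actually merged.

    import sys
    from .core import GPUStatCollection  # package-relative import, as in the diff

    def print_gpustat(json=False, debug=False, gpu_stat=None, **kwargs):
        '''Display the GPU query results on standard output.'''
        try:
            if gpu_stat is None:
                gpu_stat = GPUStatCollection()
            else:
                gpu_stat.update()
        except Exception:
            if debug:
                raise  # let the raw NVML error and its traceback propagate
            sys.stderr.write('Error on querying NVIDIA devices.'
                             ' Use --debug flag for details\n')
            sys.exit(1)
        if json:
            gpu_stat.print_json(sys.stdout)
        else:
            gpu_stat.print_formatted(sys.stdout, **kwargs)
        return gpu_stat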
@@ -32,9 +36,10 @@ def print_gpustat(json=False, debug=False, **kwargs):
         sys.exit(1)

     if json:
-        gpu_stats.print_json(sys.stdout)
+        gpu_stat.print_json(sys.stdout)
     else:
-        gpu_stats.print_formatted(sys.stdout, **kwargs)
+        gpu_stat.print_formatted(sys.stdout, **kwargs)
+    return gpu_stat


 def main(*argv):
@@ -88,6 +93,8 @@ def main(*argv):
     )
     args = parser.parse_args(argv[1:])

+    cached_stat = None
+
     if args.interval is None:  # with default value
         args.interval = 1.0
     if args.interval > 0:
@@ -102,7 +109,7 @@ def main(*argv):
             try:
                 query_start = time.time()
                 with term.location(0, 0):
-                    print_gpustat(eol_char=term.clear_eol + '\n', **vars(args))  # noqa
+                    cached_stat = print_gpustat(gpu_stat=cached_stat, eol_char=term.clear_eol + '\n', **vars(args))  # noqa
                 print(term.clear_eos, end='')
                 query_duration = time.time() - query_start
                 sleep_duration = args.interval - query_duration
@@ -111,7 +118,7 @@
         except KeyboardInterrupt:
             exit(0)
     else:
-        print_gpustat(**vars(args))
+        print_gpustat(gpu_stat=cached_stat, **vars(args))


 if __name__ == '__main__':
@@ -250,19 +250,38 @@ def jsonify(self):

 class GPUStatCollection(object):

-    def __init__(self, gpu_list):
+    def __init__(self, gpu_list=[]):
+        """The initialization argument gpu_list is remained to support
+        existing APIs"""
         self.gpus = gpu_list

         # attach additional system information
-        self.hostname = platform.node()
+        N.nvmlInit()
+        device_count = N.nvmlDeviceGetCount()
+        self.handles = [N.nvmlDeviceGetHandleByIndex(idx)
+                        for idx in range(device_count)]
+
+        self.update()
+
+    def __del__(self):
+        # sometimes delayed gc causes problem, just attempt to release
+        # NVML resources
+        try:
+            N.nvmlShutdown()
+        except Exception:
+            pass
+
+    def update(self):
+        self._update_host()
+        self._update_gpu()
+
+    def _update_host(self):
+        """Update additional host information"""
+        self.query_time = datetime.now()
+        self.hostname = platform.node()

-    @staticmethod
-    def new_query():
Review comments on removing the static method new_query():

- We should not remove this static method; it is already being used (as a public API, sort of) in other applications. Since you are lengthening the lifecycle of …
- I thought the wrapper function … It's OK to retain the legacy API, but I prefer a factory function like:

    @classmethod
    def new_query(cls):
        return cls()
+    def _update_gpu(self):
         """Query the information of all the GPUs on local machine"""

-        N.nvmlInit()

         def get_gpu_info(handle):
             """Get one GPU information specified by nvml handle"""
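As background on what caching the handles buys, here is a small illustration using pynvml directly (the NVML bindings the diff refers to as N; this snippet is not part of the PR): nvmlInit()/nvmlShutdown() and the handle lookups run once, while the cheap per-handle queries can be repeated against the same handles.

    import pynvml as N  # NVML Python bindings; import alias assumed to match gpustat's N

    N.nvmlInit()  # once per process, not once per query
    handles = [N.nvmlDeviceGetHandleByIndex(i)
               for i in range(N.nvmlDeviceGetCount())]

    for handle in handles:  # repeatable without re-initializing NVML
        name = N.nvmlDeviceGetName(handle)
        mem = N.nvmlDeviceGetMemoryInfo(handle)
        util = N.nvmlDeviceGetUtilizationRates(handle)
        print(name, mem.used, mem.total, util.gpu)

    N.nvmlShutdown()  # only when the process is done with NVML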
@@ -365,17 +384,12 @@ def _decode(b):
             return gpu_info

         # 1. get the list of gpu and status
-        gpu_list = []
-        device_count = N.nvmlDeviceGetCount()
+        self.gpus = []

-        for index in range(device_count):
-            handle = N.nvmlDeviceGetHandleByIndex(index)
+        for handle in self.handles:
             gpu_info = get_gpu_info(handle)
             gpu_stat = GPUStat(gpu_info)
-            gpu_list.append(gpu_stat)
-
-        N.nvmlShutdown()
-        return GPUStatCollection(gpu_list)
+            self.gpus.append(gpu_stat)

     def __len__(self):
         return len(self.gpus)
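Put together, the lifecycle this PR aims for looks roughly like the following usage sketch (the gpustat.core import path and the print_formatted call are assumptions based on the diff, not documented API guarantees):

    import sys
    import time
    from gpustat.core import GPUStatCollection  # import path assumed

    stats = GPUStatCollection()        # nvmlInit() + handle caching happen once here
    for _ in range(3):
        stats.update()                 # re-reads host and GPU info via the cached handles
        stats.print_formatted(sys.stdout)
        time.sleep(1.0)
    # nvmlShutdown() is attempted in __del__ when the collection is garbage-collected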
@@ -471,7 +485,7 @@ def date_handler(obj):

 def new_query():
     '''
-    Obtain a new GPUStatCollection instance by querying nvidia-smi
+    Obtain a new GPUStatCollection instance by querying nvml
     to get the list of GPUs and running process information.
     '''
-    return GPUStatCollection.new_query()
+    return GPUStatCollection()
Review comments on passing gpu_stat into print_gpustat:

- It makes less sense to take gpu_stat as an argument for this function. I was thinking of a style like (i) we first create a GPUStatCollection instance here and (ii) then print outputs using it. As the same instance should be reused, we can consider running a wait loop inside this function rather than calling it multiple times.
- What about making this function a generator, so we can keep the GPUStatCollection object in the closure?
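For reference, a sketch of the generator idea: the GPUStatCollection instance lives in the generator's frame, so main() advances the generator instead of threading a cached_stat value through print_gpustat. The gpustat_printer name is hypothetical, and this illustrates the comment rather than code from the PR.

    import sys
    from .core import GPUStatCollection

    def gpustat_printer(json=False, **kwargs):
        """Yield after each refresh; the collection stays alive in this frame."""
        gpu_stat = GPUStatCollection()       # NVML init + handle caching, once
        while True:
            if json:
                gpu_stat.print_json(sys.stdout)
            else:
                gpu_stat.print_formatted(sys.stdout, **kwargs)
            yield                            # caller decides when to refresh
            gpu_stat.update()                # reuse the cached handles

    # In main(), the watch loop could then look roughly like:
    #     printer = gpustat_printer(json=args.json)
    #     while True:
    #         next(printer)
    #         time.sleep(args.interval)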