From: Kyle Bowman Date: Wed, 20 Mar 2024 02:29:28 +0000 (-0400) Subject: Cleaned up _ArgSpecs implementation by subclassing dict. X-Git-Tag: v0.1.0~6 X-Git-Url: https://git.rocketbowman.com/?a=commitdiff_plain;h=c61b94cf62bc773e72659af62a24981c2582a627;p=proto.git Cleaned up _ArgSpecs implementation by subclassing dict. --- diff --git a/src/proto/utils.py b/src/proto/utils.py index 0d87284..5622c13 100644 --- a/src/proto/utils.py +++ b/src/proto/utils.py @@ -36,43 +36,40 @@ def get_types(fn: Callable)->dict[str]: types[name] = prm.annotation return types -class _ArgSpec: +class _ArgSpec(dict): """ _ArgSpec contains key-value pairs used for parser.add_argument(). """ def __init__(self, prm: inspect.Parameter): - self.prm = prm - self.argspecs = dict() - self.argspecs.update(self._parse_default()) - self.argspecs.update(self._parse_type()) - self.argname = self.argspecs.pop('argname') + super().__init__() + self.update(self._parse_default(prm)) + self.update(self._parse_type(prm)) - def _parse_default(self): + def _parse_default(self, prm): # ASSUME: If a function specifies a default argument, we tell the parser to consider it optional. dct = {} - if self.prm.default != self.prm.empty: + if prm.default != prm.empty: # NOTE: Argparse requires optional args to start with '-' or '--'. - # Doc reference: https://docs.python.org/3.11/library/argparse.html#id5 - dct['argname'] = "--" + self.prm.name + dct['argname'] = "--" + prm.name dct['required'] = False - dct['default'] = self.prm.default + dct['default'] = prm.default else: - dct['argname'] = self.prm.name + dct['argname'] = prm.name return dct - def _parse_type(self): + def _parse_type(self, prm): # NOTE: If you don't specify type in add_argument(), it will be parsed as a string. # Use get_argspecs() to add type-specific information to the arg_spec. - if isinstance(self.prm.annotation, type): # Basic types - return get_argspecs(self.prm.annotation(), self.argspecs) - elif hasattr(self.prm.annotation, '__args__'): # Unions + if isinstance(prm.annotation, type): # Basic types + return get_argspecs(prm.annotation()) + elif hasattr(prm.annotation, '__args__'): # Unions # ASSUME: Order of types in signatures indicate order of preference. - for type_ in self.prm.annotation.__args__: + for type_ in prm.annotation.__args__: try: - return get_argspecs(type_(), self.argspecs) + return get_argspecs(type_()) except TypeError as e: raise e else: - raise TypeError(f"Cannot instantiate. Check the type of {self.prm.annotation}") + raise TypeError(f"Cannot instantiate. Check the type of {prm.annotation}") def get_parser(fn: Callable)->argparse.ArgumentParser: """ Returns an argparse.ArgumentParser based on the function's signature. """ @@ -82,20 +79,18 @@ def get_parser(fn: Callable)->argparse.ArgumentParser: # ASSUME: It's a Pythonic standard to use self, but it's convention, not rule. Beware. if prm.name == "self" or prm.name == "cls": continue - _argspec = _ArgSpec(prm) - argspec = _argspec.argspecs - argname = _argspec.argname + argspec = _ArgSpec(prm) + argname = argspec.pop('argname') parser.add_argument(argname, **argspec) return parser @singledispatch -def get_argspecs(annotation: type, argspecs: dict)->dict: +def get_argspecs(annotation: type)->dict: """ Creates a partial argspec dictionary from a parameter annotation. """ - return argspecs + return {} @get_argspecs.register -def scalar_argspecs(annotation: Union[int, float, str], argspecs)->dict: +def scalar_argspecs(annotation: Union[int, float, str])->dict: """ Implements get_argspecs for integers, floats, and strings. """ - argspecs['type'] = type(annotation) - return argspecs + return {'type': type(annotation)} \ No newline at end of file