]> git.rocketbowman.com Git - proto.git/commitdiff
Cleaned up _ArgSpecs implementation by subclassing dict.
authorKyle Bowman <kylebowman14@gmail.com>
Wed, 20 Mar 2024 02:29:28 +0000 (22:29 -0400)
committerKyle Bowman <kylebowman14@gmail.com>
Wed, 20 Mar 2024 02:29:28 +0000 (22:29 -0400)
src/proto/utils.py

index 0d87284daddea4d43737e1cecd91415bbdc821e3..5622c1339a2b6f67553e0f60d676eddb54866414 100644 (file)
@@ -36,43 +36,40 @@ def get_types(fn: Callable)->dict[str]:
             types[name] = prm.annotation
     return types
 
-class _ArgSpec:
+class _ArgSpec(dict):
     """ _ArgSpec contains key-value pairs used for parser.add_argument(). """
 
     def __init__(self, prm: inspect.Parameter):
-        self.prm        = prm
-        self.argspecs   = dict()
-        self.argspecs.update(self._parse_default())
-        self.argspecs.update(self._parse_type())
-        self.argname = self.argspecs.pop('argname') 
+        super().__init__()
+        self.update(self._parse_default(prm))
+        self.update(self._parse_type(prm))
 
-    def _parse_default(self):
+    def _parse_default(self, prm):
         # ASSUME: If a function specifies a default argument, we tell the parser to consider it optional.
         dct = {}
-        if self.prm.default != self.prm.empty:
+        if prm.default != prm.empty: 
             # NOTE: Argparse requires optional args to start with '-' or '--'.
-            # Doc reference: https://docs.python.org/3.11/library/argparse.html#id5
-            dct['argname']  = "--" + self.prm.name
+            dct['argname']  = "--" + prm.name
             dct['required'] = False
-            dct['default']  = self.prm.default
+            dct['default']  = prm.default
         else: 
-            dct['argname']  = self.prm.name
+            dct['argname']  = prm.name
         return dct
 
-    def _parse_type(self):
+    def _parse_type(self, prm):
         # NOTE: If you don't specify type in add_argument(), it will be parsed as a string.
         # Use get_argspecs() to add type-specific information to the arg_spec.
-        if isinstance(self.prm.annotation, type):    # Basic types
-            return get_argspecs(self.prm.annotation(), self.argspecs)
-        elif hasattr(self.prm.annotation, '__args__'): # Unions
+        if isinstance(prm.annotation, type):    # Basic types
+            return get_argspecs(prm.annotation())
+        elif hasattr(prm.annotation, '__args__'): # Unions
             # ASSUME: Order of types in signatures indicate order of preference.
-            for type_ in self.prm.annotation.__args__:
+            for type_ in prm.annotation.__args__:
                 try:
-                    return get_argspecs(type_(), self.argspecs)
+                    return get_argspecs(type_())
                 except TypeError as e:
                     raise e
         else:
-            raise TypeError(f"Cannot instantiate. Check the type of {self.prm.annotation}")
+            raise TypeError(f"Cannot instantiate. Check the type of {prm.annotation}")
 
 def get_parser(fn: Callable)->argparse.ArgumentParser:
     """ Returns an argparse.ArgumentParser based on the function's signature. """
@@ -82,20 +79,18 @@ def get_parser(fn: Callable)->argparse.ArgumentParser:
         # ASSUME: It's a Pythonic standard to use self, but it's convention, not rule. Beware.
         if prm.name == "self" or prm.name == "cls":
             continue
-        _argspec = _ArgSpec(prm)
-        argspec = _argspec.argspecs
-        argname = _argspec.argname
+        argspec = _ArgSpec(prm)
+        argname = argspec.pop('argname')
         parser.add_argument(argname, **argspec)
     return parser
 
 @singledispatch
-def get_argspecs(annotation: type, argspecs: dict)->dict:
+def get_argspecs(annotation: type)->dict:
     """ Creates a partial argspec dictionary from a parameter annotation. """
-    return argspecs
+    return {}
 
 @get_argspecs.register
-def scalar_argspecs(annotation: Union[int, float, str], argspecs)->dict:
+def scalar_argspecs(annotation: Union[int, float, str])->dict:
     """ Implements get_argspecs for integers, floats, and strings. """
-    argspecs['type'] = type(annotation)
-    return argspecs
+    return {'type': type(annotation)}
     
\ No newline at end of file