]> git.rocketbowman.com Git - proto.git/commitdiff
Create _ArgSpecs class.
authorKyle Bowman <kylebowman14@gmail.com>
Wed, 20 Mar 2024 02:15:57 +0000 (22:15 -0400)
committerKyle Bowman <kylebowman14@gmail.com>
Wed, 20 Mar 2024 02:15:57 +0000 (22:15 -0400)
src/proto/utils.py

index cf527f1139b5c60525b3371fae7597701249a9b9..0d87284daddea4d43737e1cecd91415bbdc821e3 100644 (file)
@@ -36,54 +36,66 @@ def get_types(fn: Callable)->dict[str]:
             types[name] = prm.annotation
     return types
 
-def get_parser(fn: Callable)->argparse.ArgumentParser:
-    """ Returns an argparse.ArgumentParser based on the function's signature. """
-    sig = inspect.signature(fn)
-    parser = argparse.ArgumentParser() 
-    
-    for prm in sig.parameters.values():
-        argname = None
-        arg_specs = dict()
+class _ArgSpec:
+    """ _ArgSpec contains key-value pairs used for parser.add_argument(). """
 
-        # ASSUME: It's a Pythonic standard to use self, but it's convention, not rule. Beware.
-        if prm.name == "self" or prm.name == "cls":
-            continue
+    def __init__(self, prm: inspect.Parameter):
+        self.prm        = prm
+        self.argspecs   = dict()
+        self.argspecs.update(self._parse_default())
+        self.argspecs.update(self._parse_type())
+        self.argname = self.argspecs.pop('argname') 
 
+    def _parse_default(self):
         # ASSUME: If a function specifies a default argument, we tell the parser to consider it optional.
-        if prm.default != prm.empty:
+        dct = {}
+        if self.prm.default != self.prm.empty:
             # NOTE: Argparse requires optional args to start with '-' or '--'.
             # Doc reference: https://docs.python.org/3.11/library/argparse.html#id5
-            argname = "--" + prm.name
-            arg_specs['required'] = False
-            arg_specs['default']=prm.default
+            dct['argname']  = "--" + self.prm.name
+            dct['required'] = False
+            dct['default']  = self.prm.default
         else: 
-            argname = prm.name
+            dct['argname']  = self.prm.name
+        return dct
 
+    def _parse_type(self):
         # NOTE: If you don't specify type in add_argument(), it will be parsed as a string.
         # Use get_argspecs() to add type-specific information to the arg_spec.
-        if isinstance(prm.annotation, type):    # Basic types
-            arg_specs | get_argspecs(prm.annotation(), arg_specs)
-        elif hasattr(prm.annotation, '__args__'): # Unions
+        if isinstance(self.prm.annotation, type):    # Basic types
+            return get_argspecs(self.prm.annotation(), self.argspecs)
+        elif hasattr(self.prm.annotation, '__args__'): # Unions
             # ASSUME: Order of types in signatures indicate order of preference.
-            for type_ in prm.annotation.__args__:
+            for type_ in self.prm.annotation.__args__:
                 try:
-                    arg_specs | get_argspecs(type_(), arg_specs)
-                    break
+                    return get_argspecs(type_(), self.argspecs)
                 except TypeError as e:
                     raise e
         else:
-            raise TypeError(f"Cannot instantiate. Check the type of {prm.annotation}")
-        parser.add_argument(argname, **arg_specs)
+            raise TypeError(f"Cannot instantiate. Check the type of {self.prm.annotation}")
+
+def get_parser(fn: Callable)->argparse.ArgumentParser:
+    """ Returns an argparse.ArgumentParser based on the function's signature. """
+    sig = inspect.signature(fn)
+    parser = argparse.ArgumentParser() 
+    for prm in sig.parameters.values():
+        # ASSUME: It's a Pythonic standard to use self, but it's convention, not rule. Beware.
+        if prm.name == "self" or prm.name == "cls":
+            continue
+        _argspec = _ArgSpec(prm)
+        argspec = _argspec.argspecs
+        argname = _argspec.argname
+        parser.add_argument(argname, **argspec)
     return parser
 
 @singledispatch
-def get_argspecs(annotation: type, arg_specs: dict)->dict:
+def get_argspecs(annotation: type, argspecs: dict)->dict:
     """ Creates a partial argspec dictionary from a parameter annotation. """
-    return arg_specs
+    return argspecs
 
 @get_argspecs.register
-def scalar_argspecs(annotation: Union[int, float, str], arg_specs)->dict:
+def scalar_argspecs(annotation: Union[int, float, str], argspecs)->dict:
     """ Implements get_argspecs for integers, floats, and strings. """
-    arg_specs['type'] = type(annotation)
-    return arg_specs
+    argspecs['type'] = type(annotation)
+    return argspecs
     
\ No newline at end of file