From 0c89cdcab741abdf6bef140ce40ad9e79b8a48f0 Mon Sep 17 00:00:00 2001 From: Kyle Bowman Date: Sat, 16 Mar 2024 16:25:35 -0400 Subject: [PATCH] Add handling for Union types. --- develop.md | 3 +- src/proto/utils.py | 23 ++++++++------- tests/dummy.py | 32 +++++++++++++++++++++ tests/test_proto.py | 8 ++++-- tests/test_utils.py | 70 +++++++++++++++++++-------------------------- 5 files changed, 82 insertions(+), 54 deletions(-) create mode 100644 tests/dummy.py diff --git a/develop.md b/develop.md index d2a2262..357a7c5 100644 --- a/develop.md +++ b/develop.md @@ -94,11 +94,12 @@ elif hasattr(prm.annotation, 'is_protocol' and prm.annotation._is_protocol): You can use `prm.annotation.__abstractmethods__` to determine what methods need to be defined. -### Parser definition vs Runtime arguments +### Parser definition vs Runtime arguments - Union Types * If you look through Union and define the parser based on the first type... * But then at runtime, the CLI looks like the second one... * You can't go back and redefine the parser. +* This might be a problem with Union types. * Possible solution - define a type placeholder for Union. * The type placeholder could itself dispatch to runtime parse/validators? diff --git a/src/proto/utils.py b/src/proto/utils.py index fe58b23..cf527f1 100644 --- a/src/proto/utils.py +++ b/src/proto/utils.py @@ -61,20 +61,21 @@ def get_parser(fn: Callable)->argparse.ArgumentParser: # NOTE: If you don't specify type in add_argument(), it will be parsed as a string. # Use get_argspecs() to add type-specific information to the arg_spec. - arg_specs | get_argspecs(prm.annotation(), arg_specs) - + if isinstance(prm.annotation, type): # Basic types + arg_specs | get_argspecs(prm.annotation(), arg_specs) + elif hasattr(prm.annotation, '__args__'): # Unions + # ASSUME: Order of types in signatures indicate order of preference. + for type_ in prm.annotation.__args__: + try: + arg_specs | get_argspecs(type_(), arg_specs) + break + except TypeError as e: + raise e + else: + raise TypeError(f"Cannot instantiate. Check the type of {prm.annotation}") parser.add_argument(argname, **arg_specs) return parser -# NOTE: When a single dispatch function is invoked, the the first arg is inspected. -# Based on the type of the first argument, a corresponding implementation is dispatched. -# Use @foo.register to register an implementation to the single dispatch function foo. -# The following two hacks are used throughout the get_argspecs implementations: -# HACK: type(annotation) == type. But type(annotation()) == str | int | whatever. -# This hack works because of the following hack. -# HACK: The 'type' type is callable. It behaves like a constructor for it's type. -# For example: `type(42)('36')` creates an integer 36. -# It works consistently, but I haven't seen it as defined/supported behavior. @singledispatch def get_argspecs(annotation: type, arg_specs: dict)->dict: """ Creates a partial argspec dictionary from a parameter annotation. """ diff --git a/tests/dummy.py b/tests/dummy.py new file mode 100644 index 0000000..0905623 --- /dev/null +++ b/tests/dummy.py @@ -0,0 +1,32 @@ +""" +This module contains dummy functions that are used by the test suite. +""" + +from typing import Optional + +class DummyCallable: + + def __call__(self, string: str = "default", num: int=42): + return f"String: {string} \n Integer: {num} \n" + +class DummyClass: + + def dummy_method(self, string: str = "default", num: int=42): + return f"String: {string} \n Integer: {num} \n" + + @classmethod + def dummy_classmethod(cls, string: str = "default", num: int=42): + return f"String: {string} \n Integer: {num} \n" + +def dummy_fn_no_signature(string, num)->str: + return f"String: {string} \n Integer: {num} \n" + +def dummy_fn_typed(string: str, num: int)->str: + return f"String: {string} \n Integer: {num} \n" + +def dummy_fn_optional(string: Optional[str] = None, num: Optional[int] = 42)->str: + return f"String: {string} \n Integer: {num} \n" + +# NOTE: The dummy_fn_full function is the prototypical "happy path" for signagures. +def dummy_fn_full(string: str="default", num: int = 42)->str: + return f"String: {string} \n Integer: {num} \n" diff --git a/tests/test_proto.py b/tests/test_proto.py index fe16c7f..826cb9b 100644 --- a/tests/test_proto.py +++ b/tests/test_proto.py @@ -1,17 +1,21 @@ +from abc import abstractmethod from typing import Protocol import pytest from proto import command, Command class Stringable(Protocol): + @abstractmethod def __str__(self)->str: ... -def echo_fn(arg: Stringable)->str: +#def echo_fn(arg: Stringable)->str: +def echo_fn(arg: str)->str: return str(arg) @command -def echo(arg: Stringable)->str: +#def echo(arg: Stringable)->str: +def echo(arg: str)->str: return echo_fn(arg) def test_attributes(): diff --git a/tests/test_utils.py b/tests/test_utils.py index a7c4054..41f0b03 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,43 +1,11 @@ import inspect +import os import sys -from typing import Optional +from typing import Optional import pytest -from proto.utils import get_defaults, get_types, get_parser - -############## -# Begin Data # -- signature testing -############## -# TODO: (?) Move these dummy functions to their own module within test dir. -class DummyCallable: - - def __call__(self, string: str = "default", num: int=42): - return f"String: {string} \n Integer: {num} \n" - -class DummyClass: - - def dummy_method(self, string: str = "default", num: int=42): - return f"String: {string} \n Integer: {num} \n" - - @classmethod - def dummy_classmethod(cls, string: str = "default", num: int=42): - return f"String: {string} \n Integer: {num} \n" -def dummy_fn_no_signature(string, num)->str: - return f"String: {string} \n Integer: {num} \n" - -def dummy_fn_typed(string: str, num: int)->str: - return f"String: {string} \n Integer: {num} \n" - -def dummy_fn_optional(string: Optional[str], num: Optional[int] = 42)->str: - return f"String: {string} \n Integer: {num} \n" - -# NOTE: The dummy_fn_full function is the prototypical "happy path" for signagures. -def dummy_fn_full(string: str="default", num: int = 42)->str: - return f"String: {string} \n Integer: {num} \n" - -############### -# Begin Tests # -############### +from dummy import * +from proto.utils import get_defaults, get_types, get_parser def test_get_default_equivalence(): """ Ensures that defaults are treated the same amongst Callables. """ @@ -93,9 +61,6 @@ def test_validating_optional(): _, kwargs = get_defaults(dummy_fn_optional) assert isinstance(kwargs['num'],types['num']) -# NOTE: The parser.parse_args() method always returns keys and values as strings. -# You must cast values yourself to compare. -# NOTE: sys.argv[0] is the program name and is not needed. def test_get_parser_defaults(): """ If a fn default is specified, use keyword syntax (optional). """ string="not default" @@ -137,9 +102,34 @@ def test_get_parser_types_scalar(): assert args.string == string assert args.num == int(num) +def test_get_parser_types_union(): + parser=get_parser(dummy_fn_optional) + args=parser.parse_args(args=['--string','yay']) + assert args.string == "yay" + +def test_get_parser_types_union_defaults(): + parser=get_parser(dummy_fn_optional) + args=parser.parse_args(args=[]) + assert args.num == int(42) + assert args.string == None if __name__ == "__main__": import inspect sig = inspect.signature(dummy_fn_full) p = sig.parameters['string'] - p2 = sig.parameters['num'] \ No newline at end of file + p2 = sig.parameters['num'] + def dummy_fn_file(filename: Optional[os.PathLike] = None): + if filename is None: + contents = sys.stdin.read() + return contents + else: + return str(filename) + sig2=inspect.signature(dummy_fn_file) + p3 = sig2.parameters['filename'] + + from test_proto import Stringable + def echo(arg: Stringable): + return str(arg) + sig3 = inspect.signature(echo) + p4 = sig3.parameters['arg'] + # p3.annotation.__args__ = (, ) \ No newline at end of file -- 2.39.5