From 37d53f64b3cea6ca449fb8bc1d880288e050d11e Mon Sep 17 00:00:00 2001 From: Kyle Bowman Date: Mon, 20 Jan 2025 21:28:51 -0500 Subject: [PATCH] refactor: make NomList based on set builtin --- src/nom/base.py | 56 +++++++++++++++++---------------------------- src/nom/feed.py | 4 ++-- src/nom/main.py | 7 +++--- tests/test_cli.py | 10 ++++++++ tests/test_entry.py | 14 +++--------- tests/test_feed.py | 2 +- 6 files changed, 41 insertions(+), 52 deletions(-) create mode 100644 tests/test_cli.py diff --git a/src/nom/base.py b/src/nom/base.py index 8b79976..fbd4caa 100644 --- a/src/nom/base.py +++ b/src/nom/base.py @@ -3,7 +3,7 @@ from csv import DictReader, DictWriter, excel_tab from copy import copy from pathlib import Path from pydantic import BaseModel -from typing import Callable +from typing import Callable, Iterable, Optional from nom.utils import NomError @@ -23,7 +23,7 @@ class NomListItem(BaseModel): # TODO: What if there's a pipe in one of the fields? def to_str(self, delimiter: str ='|'): - return delimiter.join([v for v in self.__dict__.values()]) + return delimiter.join([str(v) for v in self.__dict__.values()]) def to_dict(self): return vars(self) @@ -33,31 +33,20 @@ class NomListItem(BaseModel): return cls(**dct) -class NomList: +class NomList(set): - def __init__(self, items=set(), delimiter: str="|"): - self.delimiter=delimiter - self.items : set[NomListItem] = items + def __init__(self, elements: Optional[Iterable[NomListItem]]=None): + if not elements: + super().__init__() + else: + super().__init__(elements) def __add__(self, other): - dct = copy(vars(self)) - dct['items'] = self.items.union(other.items) - return self.__class__(**dct) + return self.__class__(self.union(other)) - def __contains__(self, value): - return value in self.items - - def __eq__(self, other): - return self.items == other.items - - def __iter__(self): - return self.items.__iter__() - - def __len__(self): - return len(self.items) - - def merge(self, other): - self.items.update(other.items) + def select(self, predicate: Predicate): + items = {item for item in self if predicate(item)} + return self.__class__(items) def to_stdout(self): for item in self.items: @@ -82,28 +71,25 @@ class NomList: for row in reader: item = nlitem.from_dict(row) items.append(item) - return cls(items=set(items), delimiter=delimiter) + return cls(items) - def to_csv(self, file: Path): - if not self.items: + def to_csv(self, file: Path, delimiter="|"): + if not self: raise NomError("There are no entries to write.") - fieldnames=next(iter(self.items)).get_fieldnames() + fieldnames=next(iter(self)).get_fieldnames() dialect = excel_tab - dialect.delimiter=self.delimiter + dialect.delimiter=delimiter with open(file, "w") as f: writer = DictWriter(f, fieldnames=fieldnames, dialect=dialect) writer.writeheader() - for item in self.items: + for item in self: writer.writerow(item.to_dict()) def to_stdout(self): - if not self.items: + if not self: raise NomError("There are no entries to write.") - for item in self.items: + for item in self: print(item.to_str()) - -def filter(nlist: NomList, predicate: Predicate): - items = {item for item in nlist.items if predicate(item)} - return nlist.__class__(items, delimiter=nlist.delimiter) \ No newline at end of file + \ No newline at end of file diff --git a/src/nom/feed.py b/src/nom/feed.py index 7868ba2..446af4f 100644 --- a/src/nom/feed.py +++ b/src/nom/feed.py @@ -30,7 +30,7 @@ class Feed: #viewed=False, summary="no summary") items.append(entry) - return EntryList(items=items) + return EntryList(items) class FeedListItem(NomListItem): @@ -68,5 +68,5 @@ class FeedList(NomList): return cls(file.name, urls) def fetch_feeds(self, save_dir: Path): - for flitem in self.items: + for flitem in self: flitem.fetch_feed(save_dir) \ No newline at end of file diff --git a/src/nom/main.py b/src/nom/main.py index d69a441..26850f7 100644 --- a/src/nom/main.py +++ b/src/nom/main.py @@ -1,4 +1,5 @@ from pathlib import Path +import sys from nom.utils import url2filename, NomError from nom.feed import Feed, FeedList @@ -10,16 +11,16 @@ FEED_CACHE=Path.home() / ".cache" / "nom" / "feeds" FEED_LIST=Path.home() / ".local" / "share" / "nom" / "feedlist" / "default" # TODO: Flesh out CLI. -def main(): +def main(args=['nom'].append(sys.argv)): parser = cli() - args = parser.parse_args() + args = parser.parse_args(args=args) # Direct Logic feedlist=FeedList.from_csv(FEED_LIST) if args.command == "entry" and args.entry_command == "show": elist=EntryList() for flitem in feedlist: - elist.merge(flitem.to_feed().to_entrylist()) + elist += flitem.to_feed().to_entrylist() elist.to_stdout() elif args.command == "feed" and args.feed_command == "update": feedlist.fetch_feeds(FEED_CACHE) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..7c1c086 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,10 @@ +from nom.main import main + + +def test_nom_entry_show(): + main(args='entry show'.split(' ')) + assert True + +def test_nom_feed_update(): + main(args='feed update'.split(' ')) + assert True \ No newline at end of file diff --git a/tests/test_entry.py b/tests/test_entry.py index 97f8f5d..df215b3 100644 --- a/tests/test_entry.py +++ b/tests/test_entry.py @@ -4,7 +4,6 @@ from copy import copy import pytest from nom.entry import EntryList, EntryListItem -from nom.filter import is_viewed from test_feed import feedlist @@ -21,7 +20,7 @@ def elist_multi(): @pytest.fixture def elist_item(elist_single): - return next(iter(elist_single.items)) + return next(iter(elist_single)) def test_elist_constructors(elist_single): @@ -37,18 +36,11 @@ def test_eli_to_from_dict_idempotency(elist_item): remade = elist_item.from_dict(elist_item.to_dict()) assert remade == elist_item -def test_elist_merge(elist_multi, elist_single): - original_length = len(elist_multi) - elist_multi.merge(elist_single) - assert len(elist_multi) == original_length + 1 - def test_elist_addition(elist_multi, elist_single): sum_ = elist_multi + elist_single assert len(sum_) == len(elist_multi) + len(elist_single) assert isinstance(sum_,EntryList) -def test_elist_filter(elist_multi): - #viewed = filter(is_viewed, elist_multi) - from nom.base import filter - viewed=filter(elist_multi, lambda e: e.viewed) +def test_elist_select(elist_multi): + viewed = elist_multi.select(lambda e: e.viewed) assert len(viewed) < len(elist_multi) diff --git a/tests/test_feed.py b/tests/test_feed.py index f53f01c..d92fbe3 100644 --- a/tests/test_feed.py +++ b/tests/test_feed.py @@ -16,5 +16,5 @@ def test_flist_from_csv(feedlist): def test_to_entrylist(feedlist): elist = EntryList() for flitem in feedlist: - elist.merge(flitem.to_feed().to_entrylist()) + elist += flitem.to_feed().to_entrylist() assert len(elist) == 5 \ No newline at end of file -- 2.39.5