Refactor and add tests

This commit is contained in:
Pavle Portic 2021-08-06 00:07:30 +02:00
parent da0414f9ba
commit e3d0e266c9
Signed by: TheEdgeOfRage
GPG Key ID: F2AB38285780DE3D
19 changed files with 440 additions and 62 deletions

View File

@ -1,3 +1,6 @@
[mypy-aioresponses.*]
ignore_missing_imports=true
[mypy-feedparser.*]
ignore_missing_imports=true

View File

@ -9,7 +9,7 @@ setup-dev:
pip install -e .
test:
python -m unittest discover $(CURDIR)/$(TESTS)
python -m pytest --cov $(CURDIR)/$(NAME)
flake8:
@flake8 $(FILES_PY)
@ -23,5 +23,11 @@ isort:
validate: flake8 mypy isort
coverage:
python -m coverage run -m unittest discover $(CURDIR)/$(TESTS)
python -m coverage html
python -m pytest --cov $(CURDIR)/$(NAME) --cov-report html
clean:
rm -rf $(CURDIR)/build
rm -rf $(CURDIR)/dist
rm -rf $(CURDIR)/htmlcov
rm -rf $(CURDIR)/.coverage
rm -rf $(CURDIR)/$(NAME).egg-info

View File

@ -1,3 +1,6 @@
aioresponses==0.7.*
flake8==3.*
jedi==0.*
mypy==0.*
pytest==6.2.*
pytest-mock==3.6.*

0
tests/__init__.py Normal file
View File

34
tests/constants.py Normal file
View File

@ -0,0 +1,34 @@
from datetime import datetime
from ytrssil.types import ChannelData, VideoData
FEED_XML: str = '''
<?xml version="1.0" encoding="UTF-8"?>
<feed xmlns:yt="http://www.youtube.com/xml/schemas/2015">
<yt:channelId>channel_id</yt:channelId>
<title>channel_name</title>
<entry>
<yt:videoId>video_id</yt:videoId>
<yt:channelId>channel_id</yt:channelId>
<title>video_name</title>
<link rel="alternate" href="https://www.youtube.com/watch?v=video_id"/>
<published>1970-01-01T00:00:00+00:00</published>
</entry>
</feed>
'''
TEST_VIDEO_DATA: VideoData = {
'video_id': 'video_id',
'name': 'video_name',
'url': 'https://www.youtube.com/watch?v=video_id',
'channel_id': 'channel_id',
'channel_name': 'channel_name',
'timestamp': datetime.fromisoformat('1970-01-01T00:00:00+00:00'),
'watch_timestamp': None,
}
TEST_CHANNEL_DATA: ChannelData = {
'channel_id': 'channel_id',
'name': 'channel_name',
'new_videos': {}
}

17
tests/test_bindings.py Normal file
View File

@ -0,0 +1,17 @@
from ytrssil.bindings import setup_dependencies
from ytrssil.config import Configuration
from ytrssil.fetch import Fetcher, AioHttpFetcher
from ytrssil.parse import Parser, FeedparserParser
from ytrssil.repository import ChannelRepository, SqliteChannelRepository
def test_setup_dependencies() -> None:
injector = setup_dependencies()
config = injector.get_instance(Configuration)
assert isinstance(config, Configuration)
assert isinstance(
injector.get_instance(ChannelRepository),
SqliteChannelRepository,
)
assert isinstance(injector.get_instance(Fetcher), AioHttpFetcher)
assert isinstance(injector.get_instance(Parser), FeedparserParser)

83
tests/test_cli.py Normal file
View File

@ -0,0 +1,83 @@
from datetime import datetime
from unittest.mock import MagicMock
from pytest_mock import MockerFixture
from tests.constants import TEST_CHANNEL_DATA, TEST_VIDEO_DATA
from ytrssil import cli
from ytrssil.datatypes import Channel, Video
def test_user_query(mocker: MockerFixture) -> None:
def mock_query(input: bytes) -> tuple[bytes, bytes]:
videos = input.decode('UTF-8').split('\n')
return (videos[0].encode('UTF-8'), b'')
popen_mock: MagicMock = mocker.patch.object(cli, 'Popen')
attrs = {'communicate': mock_query}
communicate_mock = mocker.MagicMock()
communicate_mock.configure_mock(**attrs)
popen_mock.return_value = communicate_mock
videos = {
f'video_id_{i}': Video(
video_id=f'video_id_{i}',
name='video',
url='url',
timestamp=datetime.utcnow(),
channel_id='channel_id',
channel_name='channel',
)
for i in range(2)
}
ret = cli.user_query(videos=videos)
assert ret == [videos['video_id_0']]
def test_watch_videos(mocker: MockerFixture) -> None:
repository_mock = mocker.MagicMock()
update_video = mocker.MagicMock()
repository_mock.__enter__.return_value.update_video = update_video
fetcher_mock = mocker.MagicMock()
channel = Channel.from_dict(TEST_CHANNEL_DATA)
video = Video.from_dict(TEST_VIDEO_DATA)
channel.add_video(video)
fetcher_mock.fetch_new_videos.return_value = (
{channel.channel_id: channel},
{video.video_id: video},
)
query_mock = mocker.patch.object(cli, 'user_query')
query_mock.return_value = [video]
fork_mock = mocker.patch.object(cli, 'fork')
cli.watch_videos(repository_manager=repository_mock, fetcher=fetcher_mock)
fork_mock.assert_called_once()
update_video.assert_called_once() # repository is a context manager
def test_main_no_arg(mocker: MockerFixture) -> None:
mock = mocker.patch.object(cli, 'watch_videos')
cli.main(['ytrssil'])
assert mock.called_once
def test_main_watch_videos(mocker: MockerFixture) -> None:
mock = mocker.patch.object(cli, 'watch_videos')
cli.main(['ytrssil', 'watch_videos'])
assert mock.called_once
def test_main_history(mocker: MockerFixture) -> None:
mock = mocker.patch.object(cli, 'watch_history')
cli.main(['ytrssil', 'history'])
assert mock.called_once
def test_main_mark(mocker: MockerFixture) -> None:
mock = mocker.patch.object(cli, 'mark_as_watched')
cli.main(['ytrssil', 'mark', datetime.utcnow().isoformat()])
assert mock.called_once

96
tests/test_datatypes.py Normal file
View File

@ -0,0 +1,96 @@
from datetime import datetime
from tests.constants import TEST_CHANNEL_DATA, TEST_VIDEO_DATA
from ytrssil.datatypes import Channel, Video
def test_video_str() -> None:
string = str(Video.from_dict(TEST_VIDEO_DATA))
assert string == 'channel_name - video_name - video_id'
def test_channel_str() -> None:
channel = Channel.from_dict({
**TEST_CHANNEL_DATA,
'new_videos': {
'video_id': Video(
video_id='video_id',
name='video_name',
url='https://www.youtube.com/watch?v=video_id',
channel_id='channel_id',
channel_name='channel_name',
timestamp=datetime.fromisoformat('1970-01-01T00:00:00+00:00'),
watch_timestamp=None,
),
},
})
string = str(channel)
assert string == 'channel_name - 1'
def test_channel_add_new_video() -> None:
channel = Channel.from_dict(TEST_CHANNEL_DATA)
added_video = channel.add_video(Video(
video_id='video_id',
name='video_name',
url='https://www.youtube.com/watch?v=video_id',
channel_id='channel_id',
channel_name='channel_name',
timestamp=datetime.fromisoformat('1970-01-01T00:00:00+00:00'),
watch_timestamp=None,
))
assert added_video
assert list(channel.new_videos.keys()) == ['video_id']
def test_channel_add_existing_video() -> None:
channel = Channel.from_dict({
**TEST_CHANNEL_DATA,
'new_videos': {
'video_id': Video(
video_id='video_id',
name='video_name',
url='https://www.youtube.com/watch?v=video_id',
channel_id='channel_id',
channel_name='channel_name',
timestamp=datetime.fromisoformat('1970-01-01T00:00:00+00:00'),
watch_timestamp=None,
),
},
})
added_video = channel.add_video(Video(
video_id='video_id',
name='video_name',
url='https://www.youtube.com/watch?v=video_id',
channel_id='channel_id',
channel_name='channel_name',
timestamp=datetime.fromisoformat('1970-01-01T00:00:00+00:00'),
watch_timestamp=None,
))
assert not added_video
assert list(channel.new_videos.keys()) == ['video_id']
def test_channel_mark_video_as_watched() -> None:
video = Video(
video_id='video_id',
name='video_name',
url='https://www.youtube.com/watch?v=video_id',
channel_id='channel_id',
channel_name='channel_name',
timestamp=datetime.fromisoformat('1970-01-01T00:00:00+00:00'),
watch_timestamp=None,
)
channel = Channel.from_dict({
**TEST_CHANNEL_DATA,
'new_videos': {'video_id': video},
})
channel.mark_video_as_watched(video)
assert list(channel.new_videos.keys()) == []
assert list(channel.watched_videos.keys()) == ['video_id']

43
tests/test_fetch.py Normal file
View File

@ -0,0 +1,43 @@
from __future__ import annotations
from collections.abc import Iterable
from aioresponses import aioresponses
from tests.constants import FEED_XML, TEST_CHANNEL_DATA, TEST_VIDEO_DATA
from ytrssil.config import Configuration
from ytrssil.datatypes import Channel, Video
from ytrssil.fetch import AioHttpFetcher, Fetcher
def test_fetch_new_videos():
class MockFetcher(Fetcher):
def fetch_feeds(self, urls: Iterable[str]) -> Iterable[str]:
return [FEED_XML]
fetcher = MockFetcher()
channel = Channel.from_dict(TEST_CHANNEL_DATA)
video = Video.from_dict(TEST_VIDEO_DATA)
channel.add_video(video)
channels, new_videos = fetcher.fetch_new_videos(
config=Configuration(
feed_url_getter_type='static',
feed_urls=[''],
),
parser=lambda _: channel,
)
assert channels[channel.channel_id] == channel
assert new_videos[TEST_VIDEO_DATA['video_id']] == video
def test_aiohttpfetcher_fetch_feeds():
feed_url = 'test_url'
with aioresponses() as mocked:
mocked.get(
url=feed_url,
body=FEED_XML,
)
fetcher = AioHttpFetcher()
xml = fetcher.fetch_feeds([feed_url])
assert xml == [FEED_XML]

47
tests/test_parse.py Normal file
View File

@ -0,0 +1,47 @@
from unittest import TestCase
from unittest.mock import Mock
from inject import Binder, clear_and_configure
from tests.constants import FEED_XML, TEST_CHANNEL_DATA, TEST_VIDEO_DATA
from ytrssil.config import Configuration
from ytrssil.datatypes import Channel, Video
from ytrssil.exceptions import ChannelNotFound
from ytrssil.parse import create_feed_parser, FeedparserParser
from ytrssil.repository import ChannelRepository
def test_feedparser_channel_exists() -> None:
channel = Channel.from_dict(TEST_CHANNEL_DATA)
mock_repo = Mock()
mock_repo.get_channel.return_value = channel
parser = FeedparserParser(channel_repository=mock_repo)
assert parser(FEED_XML) == channel
def test_feedparser_new_channel() -> None:
channel = Channel.from_dict(TEST_CHANNEL_DATA)
channel.add_video(Video.from_dict(TEST_VIDEO_DATA))
mock_repo = Mock()
mock_repo.get_channel.side_effect = ChannelNotFound()
parser = FeedparserParser(channel_repository=mock_repo)
assert parser(FEED_XML) == channel
class TestCreateParser(TestCase):
def setUp(self) -> None:
clear_and_configure(self.inject)
def inject(self, binder: Binder) -> None:
binder.bind(ChannelRepository, Mock())
def test_create_feedparser_parser(self) -> None:
parser = create_feed_parser(Configuration(parser_type='feedparser'))
self.assertIsInstance(parser, FeedparserParser)
def test_fail_create_parser(self) -> None:
with self.assertRaises(Exception) as e:
create_feed_parser(Configuration(parser_type='fail'))
self.assertEqual('Unknown feed parser type: "fail"', e.exception)

View File

@ -19,3 +19,7 @@ def get_new_videos() -> list[Video]:
return new_videos
return list(_get_new_videos().values())
def get_new_video_count() -> int:
return len(get_new_videos())

View File

@ -1,4 +1,4 @@
from inject import Binder, Injector, configure, get_injector_or_die
from inject import Binder, Injector, clear_and_configure, get_injector_or_die
from ytrssil.config import Configuration
from ytrssil.fetch import Fetcher, create_fetcher
@ -9,13 +9,12 @@ from ytrssil.repository import ChannelRepository, create_channel_repository
def dependency_configuration(binder: Binder) -> None:
config = Configuration()
binder.bind(Configuration, config)
binder.bind_to_provider
binder.bind_to_constructor(ChannelRepository, create_channel_repository)
binder.bind_to_constructor(Fetcher, create_fetcher)
binder.bind_to_constructor(Parser, create_feed_parser)
def setup_dependencies() -> Injector:
configure(dependency_configuration)
clear_and_configure(dependency_configuration)
return get_injector_or_die()

View File

@ -111,11 +111,11 @@ def mark_as_watched(
return 0
def main() -> int:
def main(args: list[str] = argv) -> int:
setup_dependencies()
command: str
try:
command = argv[1]
command = args[1]
except IndexError:
command = 'watch'
@ -124,7 +124,7 @@ def main() -> int:
elif command == 'history':
return watch_history()
elif command == 'mark':
up_to_date = datetime.fromisoformat(argv[2])
up_to_date = datetime.fromisoformat(args[2])
return mark_as_watched(up_to_date=up_to_date)
else:
print(f'Unknown command "{command}"', file=stderr)

View File

@ -1,19 +1,30 @@
from __future__ import annotations
import os
from collections.abc import Iterator
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any
from ytrssil.constants import config_dir
@dataclass
class Configuration:
feed_url_getter_type: str = 'file'
feed_urls: list[str] = field(default_factory=lambda: list())
channel_repository_type: str = 'sqlite'
fetcher_type: str = 'aiohttp'
feed_parser_type: str = 'feedparser'
parser_type: str = 'feedparser'
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> Configuration:
return cls(**config_dict)
def get_feed_urls() -> Iterator[str]:
file_path = os.path.join(config_dir, 'feeds')
with open(file_path, 'r') as f:
for line in f:
yield line.strip()
def get_feed_urls(self) -> Iterator[str]:
if self.feed_url_getter_type == 'file':
file_path = os.path.join(config_dir, 'feeds')
with open(file_path, 'r') as f:
for line in f:
yield line.strip()
elif self.feed_url_getter_type == 'static':
yield from self.feed_urls

View File

@ -1,6 +1,9 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime
from typing import Union
from typing import Optional
from ytrssil.types import ChannelData, VideoData
@dataclass
@ -11,17 +14,20 @@ class Video:
channel_id: str
channel_name: str
timestamp: datetime
watch_timestamp: Union[datetime, None] = None
watch_timestamp: Optional[datetime] = None
def __str__(self) -> str:
return f'{self.channel_name} - {self.name} - {self.video_id}'
@classmethod
def from_dict(cls, data: VideoData) -> Video:
return cls(**data)
@dataclass
class Channel:
channel_id: str
name: str
url: str
new_videos: dict[str, Video] = field(default_factory=lambda: dict())
watched_videos: dict[str, Video] = field(default_factory=lambda: dict())
@ -35,18 +41,18 @@ class Channel:
self.new_videos[video.video_id] = video
return True
def remove_old_videos(self) -> None:
vid_list: list[Video] = sorted(
self.watched_videos.values(),
key=lambda x: x.timestamp,
)
for video in vid_list[15:]:
self.watched_videos.pop(video.video_id)
def mark_video_as_watched(self, video: Video) -> None:
self.new_videos.pop(video.video_id)
self.watched_videos[video.video_id] = video
self.remove_old_videos()
def __str__(self) -> str:
return f'{self.name} - {len(self.new_videos)}'
@classmethod
def from_dict(cls, data: ChannelData) -> Channel:
return cls(
channel_id=data['channel_id'],
name=data['name'],
new_videos=data.get('new_videos', {}).copy(),
watched_videos=data.get('watched_videos', {}).copy(),
)

View File

@ -5,24 +5,26 @@ from collections.abc import Iterable
from aiohttp import ClientResponse, ClientSession
from inject import autoparams
from ytrssil.config import Configuration, get_feed_urls
from ytrssil.config import Configuration
from ytrssil.datatypes import Channel, Video
from ytrssil.parse import Parser
from ytrssil.repository import ChannelRepository
class Fetcher(metaclass=ABCMeta):
@abstractmethod
def fetch_feeds(self, urls: Iterable[str]) -> Iterable[str]:
def fetch_feeds(
self,
urls: Iterable[str],
) -> Iterable[str]: # pragma: no cover
pass
@autoparams('parser', 'repository')
@autoparams()
def fetch_new_videos(
self,
config: Configuration,
parser: Parser,
repository: ChannelRepository,
) -> tuple[dict[str, Channel], dict[str, Video]]:
feed_urls = get_feed_urls()
feed_urls = config.get_feed_urls()
channels: dict[str, Channel] = {}
new_videos: dict[str, Video] = {}
for feed in self.fetch_feeds(feed_urls):
@ -35,11 +37,11 @@ class Fetcher(metaclass=ABCMeta):
class AioHttpFetcher(Fetcher):
async def request(self, session: ClientSession, url: str) -> ClientResponse:
return await session.request(method='GET', url=url)
return await session.get(url=url)
async def async_fetch_feeds(self, urls: Iterable[str]) -> Iterable[str]:
async with ClientSession() as session:
responses: list[ClientResponse] = await gather(*[
responses: Iterable[ClientResponse] = await gather(*[
self.request(session, url) for url in urls
])
return [

View File

@ -16,7 +16,7 @@ class Parser(metaclass=ABCMeta):
self.repository = channel_repository
@abstractmethod
def __call__(self, feed_content: str) -> Channel:
def __call__(self, feed_content: str) -> Channel: # pragma: no cover
pass
@ -30,7 +30,6 @@ class FeedparserParser(Parser):
channel = Channel(
channel_id=channel_id,
name=d['feed']['title'],
url=d['feed']['link'],
)
self.repository.create_channel(channel)
@ -51,7 +50,7 @@ class FeedparserParser(Parser):
@autoparams()
def create_feed_parser(config: Configuration) -> Parser:
parser_type = config.feed_parser_type
parser_type = config.parser_type
if parser_type == 'feedparser':
return FeedparserParser()
else:

View File

@ -4,7 +4,8 @@ import os
from abc import ABCMeta, abstractmethod
from datetime import datetime
from sqlite3 import connect
from typing import Any, Union
from types import TracebackType
from typing import Any, Optional, Type
from inject import autoparams
@ -16,40 +17,48 @@ from ytrssil.exceptions import ChannelNotFound
class ChannelRepository(metaclass=ABCMeta):
@abstractmethod
def __enter__(self) -> ChannelRepository:
def __enter__(self) -> ChannelRepository: # pragma: no cover
pass
@abstractmethod
def __exit__(
self,
exc_type: Any,
exc_value: Any,
exc_traceback: Any,
) -> None:
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None: # pragma: no cover
pass
@abstractmethod
def get_channel(self, channel_id: str) -> Channel:
def get_channel(self, channel_id: str) -> Channel: # pragma: no cover
pass
@abstractmethod
def get_all_channels(self) -> list[Channel]:
def get_all_channels(self) -> list[Channel]: # pragma: no cover
pass
@abstractmethod
def get_watched_videos(self) -> dict[str, Video]:
def get_watched_videos(self) -> dict[str, Video]: # pragma: no cover
pass
@abstractmethod
def create_channel(self, channel: Channel) -> None:
def create_channel(self, channel: Channel) -> None: # pragma: no cover
pass
@abstractmethod
def add_new_video(self, channel: Channel, video: Video) -> None:
def add_new_video(
self,
channel: Channel,
video: Video,
) -> None: # pragma: no cover
pass
@abstractmethod
def update_video(self, video: Video, watch_timestamp: datetime) -> None:
def update_video(
self,
video: Video,
watch_timestamp: datetime,
) -> None: # pragma: no cover
pass
@ -65,7 +74,7 @@ class SqliteChannelRepository(ChannelRepository):
cursor.execute('PRAGMA foreign_keys = ON')
cursor.execute(
'CREATE TABLE IF NOT EXISTS channels ('
'channel_id VARCHAR PRIMARY KEY, name VARCHAR, url VARCHAR UNIQUE)'
'channel_id VARCHAR PRIMARY KEY, name VARCHAR)'
)
cursor.execute(
'CREATE TABLE IF NOT EXISTS videos ('
@ -82,9 +91,9 @@ class SqliteChannelRepository(ChannelRepository):
def __exit__(
self,
exc_type: Any,
exc_value: Any,
exc_traceback: Any,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.connection.close()
@ -92,7 +101,6 @@ class SqliteChannelRepository(ChannelRepository):
return {
'channel_id': channel.channel_id,
'name': channel.name,
'url': channel.url,
}
def channel_data_to_channel(
@ -102,7 +110,6 @@ class SqliteChannelRepository(ChannelRepository):
channel = Channel(
channel_id=channel_data[0],
name=channel_data[1],
url=channel_data[2],
)
for video in self.get_videos_for_channel(channel):
if video.watch_timestamp is not None:
@ -113,7 +120,7 @@ class SqliteChannelRepository(ChannelRepository):
return channel
def video_to_params(self, video: Video) -> dict[str, Any]:
watch_timestamp: Union[str, None]
watch_timestamp: Optional[str]
if video.watch_timestamp is not None:
watch_timestamp = video.watch_timestamp.isoformat()
else:
@ -134,7 +141,7 @@ class SqliteChannelRepository(ChannelRepository):
channel_id: str,
channel_name: str,
) -> Video:
watch_timestamp: Union[datetime, None]
watch_timestamp: Optional[datetime]
if video_data[4] is not None:
watch_timestamp = datetime.fromisoformat(video_data[4])
else:
@ -193,7 +200,7 @@ class SqliteChannelRepository(ChannelRepository):
def get_watched_videos(self) -> dict[str, Video]:
cursor = self.connection.cursor()
cursor.execute(
'SELECT video_id, videos.name, videos.url, timestamp, '
'SELECT video_id, videos.name, url, timestamp, '
'watch_timestamp, channels.channel_id, channels.name FROM videos '
'LEFT JOIN channels ON channels.channel_id=videos.channel_id WHERE '
'watch_timestamp IS NOT NULL ORDER BY timestamp'
@ -211,7 +218,7 @@ class SqliteChannelRepository(ChannelRepository):
def create_channel(self, channel: Channel) -> None:
cursor = self.connection.cursor()
cursor.execute(
'INSERT INTO channels VALUES (:channel_id, :name, :url)',
'INSERT INTO channels VALUES (:channel_id, :name)',
self.channel_to_params(channel),
)
self.connection.commit()
@ -220,7 +227,7 @@ class SqliteChannelRepository(ChannelRepository):
cursor = self.connection.cursor()
cursor.execute(
'UPDATE channels SET channel_id = :channel_id, name = :name, '
'url = :url WHERE channel_id=:channel_id',
'WHERE channel_id=:channel_id',
self.channel_to_params(channel),
)
self.connection.commit()

18
ytrssil/types.py Normal file
View File

@ -0,0 +1,18 @@
from datetime import datetime
from typing import Optional, TypedDict
class VideoData(TypedDict):
video_id: str
name: str
url: str
channel_id: str
channel_name: str
timestamp: datetime
watch_timestamp: Optional[datetime]
class ChannelData(TypedDict):
channel_id: str
name: str
new_videos: dict