Skip to content

Commit

Permalink
Add support for compound file extensions.
Browse files Browse the repository at this point in the history
Created a new class ReaderTree that is an infinitely
nested defaultdict containing components of the extension.
See comments on PR getpelican#2816.
  • Loading branch information
holden-nelson committed Jan 7, 2022
1 parent 1b87ef6 commit 2e1b37b
Show file tree
Hide file tree
Showing 2 changed files with 237 additions and 6 deletions.
160 changes: 155 additions & 5 deletions pelican/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import logging
import os
import re
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from html import escape
from html.parser import HTMLParser
from io import StringIO
from functools import reduce
import operator

import docutils
import docutils.core
Expand Down Expand Up @@ -496,8 +498,8 @@ class Readers(FileStampDataCacher):

def __init__(self, settings=None, cache_name=''):
self.settings = settings or {}
self.readers = {}
self.reader_classes = {}
self.readers = ReaderTree()
self.reader_classes = ReaderTree()

for cls in [BaseReader] + BaseReader.__subclasses__():
if not cls.enabled:
Expand Down Expand Up @@ -542,8 +544,7 @@ def read_file(self, base_path, path, content_class=Page, fmt=None,
source_path, content_class.__name__)

if not fmt:
_, ext = os.path.splitext(os.path.basename(path))
fmt = ext[1:]
fmt = self.readers.get_format(path)

if fmt not in self.readers:
raise TypeError(
Expand Down Expand Up @@ -746,3 +747,152 @@ def parse_path_metadata(source_path, settings=None, process=None):
v = process(k, v)
metadata[k] = v
return metadata


class ReaderTree():

def __init__(self):
self.tree_dd = ReaderTree._rec_dd()

def __str__(self):
return str(ReaderTree._rec_dd_to_dict(self.tree_dd))

def __iter__(self):
for key in ReaderTree._rec_get_next_key(self.tree_dd):
yield key

def __setitem__(self, key, value):
components = reversed(key.split('.'))
reduce(operator.getitem, components, self.tree_dd)[''] = value

def __getitem__(self, key):
components = reversed(key.split('.'))
value = reduce(operator.getitem, components, self.tree_dd)
if value:
return value['']
else:
raise KeyError

def __delitem__(self, key):
value = ReaderTree._rec_del_item(self.tree_dd, key)
if not value:
raise KeyError

def __contains__(self, item):
try:
self[item]
return True
except KeyError:
return False

def __len__(self):
return len(list(self.keys()))

def keys(self):
return self.__iter__()

def values(self):
for value in ReaderTree._rec_get_next_value(self.tree_dd):
yield value

def items(self):
return zip(self.keys(), self.values())

def get(self, key):
return self[key]

def setdefault(self, key, value):
if key in self:
return self[key]
else:
self[key] = value
return value

def clear(self):
self.tree_dd.clear()

def pop(self, key, default=None):
if key in self:
value = self[key]
del self[key]
return value
elif default:
return default
else:
raise KeyError

def copy(self):
return self.tree_dd.copy()

def update(self, d):
for key, value in d.items():
self[key] = value

def get_format(self, filename):
ext = ReaderTree._rec_get_fmt_from_filename(self.tree_dd, filename)
return ext[1:]

def as_dict(self):
return ReaderTree._rec_dd_to_dict(self.tree_dd)

@staticmethod
def _rec_dd():
return defaultdict(ReaderTree._rec_dd)

@staticmethod
def _rec_dd_to_dict(dd):
d = dict(dd)

for key, value in d.items():
if type(value) == defaultdict:
d[key] = ReaderTree._rec_dd_to_dict(value)

return d

@staticmethod
def _rec_get_next_key(d):
for key in d:
if key != '':
if '' in d[key]:
yield key
if type(d[key]) == defaultdict:
for component in ReaderTree._rec_get_next_key(d[key]):
yield '.'.join([component, key])

@staticmethod
def _rec_get_next_value(d):
for key, value in d.items():
if key == '':
yield value
else:
if type(d[key]) == defaultdict:
yield from ReaderTree._rec_get_next_value(d[key])

@staticmethod
def _rec_del_item(d, intended_key):
if intended_key in d:
value = d[intended_key]['']
del d[intended_key]['']
return value
else:
for key in d:
if type(d[key]) == defaultdict:
ReaderTree._rec_del_item(d[key], intended_key)

return None

@staticmethod
def _rec_get_fmt_from_filename(d, filename):
if '.' in filename:
file, ext = os.path.splitext(filename)
fmt = ext[1:]

if fmt in d:
next_component = ReaderTree._rec_get_fmt_from_filename(d[fmt], file)
return '.'.join([next_component, fmt])
elif '' in d:
return fmt
else:
raise TypeError("No valid extension found")
else:
return ''
83 changes: 82 additions & 1 deletion pelican/tests/test_readers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from unittest.mock import patch
from unittest.mock import patch, Mock

from pelican import readers
from pelican.tests.support import get_settings, unittest
Expand Down Expand Up @@ -76,6 +76,18 @@ def test_readfile_unknown_extension(self):
with self.assertRaises(TypeError):
self.read_file(path='article_with_metadata.unknownextension')

with self.assertRaises(TypeError):
self.read_file(path='article_with.compound.extension')

def test_readfile_compound_extension(self):
CompoundReader = Mock()

# throws type error b/c of mock
with self.assertRaises(TypeError):
self.read_file(path='article_with.compound.extension',
READERS={'compound.extension': CompoundReader})
CompoundReader.read.assert_called_with('article_with.compound.extension')

def test_readfile_path_metadata_implicit_dates(self):
test_file = 'article_with_metadata_implicit_dates.html'
page = self.read_file(path=test_file, DEFAULT_DATE='fs')
Expand Down Expand Up @@ -918,3 +930,72 @@ def test_article_with_inline_svg(self):
'title': 'Article with an inline SVG',
}
self.assertDictHasSubset(page.metadata, expected)


class ReaderTreeTest(unittest.TestCase):

def setUp(self):

readers_and_exts = {
'BaseReader': ['static'],
'RstReader': ['rst'],
'HtmlReader': ['htm', 'html'],
'MDReader': ['md', 'mk', 'mkdown', 'mkd'],
'MDeepReader': ['md.html'],
'FooReader': ['foo.bar.baz.yaz']
}

self.reader_classes = readers.ReaderTree()

for reader, exts in readers_and_exts.items():
for ext in exts:
self.reader_classes[ext] = reader

def test_correct_mapping_generated(self):
expected_mapping = {
'static': {'': 'BaseReader'},
'rst': {'': 'RstReader'},
'htm': {'': 'HtmlReader'},
'html': {
'': 'HtmlReader',
'md': {'': 'MDeepReader'}
},
'md': {'': 'MDReader'},
'mk': {'': 'MDReader'},
'mkdown': {'': 'MDReader'},
'mkd': {'': 'MDReader'},
'yaz': {
'baz': {
'bar': {
'foo': {'': 'FooReader'}}}}}

self.assertEqual(expected_mapping, self.reader_classes.as_dict())

def test_containment(self):
self.assertTrue('md.html' in self.reader_classes)
self.assertTrue('html' in self.reader_classes)
self.assertFalse('txt' in self.reader_classes)

def test_deletion(self):
self.assertTrue('rst' in self.reader_classes)
del self.reader_classes['rst']
self.assertFalse('rst' in self.reader_classes)

def test_update(self):
self.reader_classes.update({
'new.ext': 'NewExtReader',
'txt': 'TxtReader'
})
self.assertEqual(self.reader_classes['new.ext'], 'NewExtReader')
self.assertEqual(self.reader_classes['txt'], 'TxtReader')

def test_get_format(self):
html_ext = self.reader_classes.get_format('text.html')
md_ext = self.reader_classes.get_format('another.md')
compound_ext = self.reader_classes.get_format('compound.md.html')
no_ext = self.reader_classes.get_format('no_extension')

self.assertEqual(html_ext, 'html')
self.assertEqual(md_ext, 'md')
self.assertEqual(compound_ext, 'md.html')
self.assertEqual(no_ext, '')

0 comments on commit 2e1b37b

Please sign in to comment.