summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- libmat2/__init__.py 6
-rw-r--r-- libmat2/office.py 14
-rwxr-xr-x mat2 17
-rw-r--r-- tests/test_policy.py 11
4 files changed, 25 insertions, 23 deletions
diff --git a/libmat2/__init__.py b/libmat2/__init__.py
index bf4e813..8a5b064 100644
--- a/libmat2/__init__.py
+++ b/libmat2/__init__.py
@@ -2,6 +2,7 @@
2 2
3import os 3import os
4import collections 4import collections
5from enum import Enum
5import importlib 6import importlib
6from typing import Dict, Optional 7from typing import Dict, Optional
7 8
@@ -62,3 +63,8 @@ def check_dependencies() -> dict:
62 ret[value] = False # pragma: no cover 63 ret[value] = False # pragma: no cover
63 64
64 return ret 65 return ret
66
67class UnknownMemberPolicy(Enum):
68 ABORT = 'abort'
69 OMIT = 'omit'
70 KEEP = 'keep'
diff --git a/libmat2/office.py b/libmat2/office.py
index 29100df..60c5478 100644
--- a/libmat2/office.py
+++ b/libmat2/office.py
@@ -9,7 +9,7 @@ from typing import Dict, Set, Pattern
9 9
10import xml.etree.ElementTree as ET # type: ignore 10import xml.etree.ElementTree as ET # type: ignore
11 11
12from . import abstract, parser_factory 12from . import abstract, parser_factory, UnknownMemberPolicy
13 13
14# Make pyflakes happy 14# Make pyflakes happy
15assert Set 15assert Set
@@ -37,8 +37,8 @@ class ArchiveBasedAbstractParser(abstract.AbstractParser):
37 files_to_omit = set() # type: Set[Pattern] 37 files_to_omit = set() # type: Set[Pattern]
38 38
39 # what should the parser do if it encounters an unknown file in 39 # what should the parser do if it encounters an unknown file in
40 # the archive? valid policies are 'abort', 'omit', 'keep' 40 # the archive?
41 unknown_member_policy = 'abort' # type: str 41 unknown_member_policy = UnknownMemberPolicy.ABORT # type: UnknownMemberPolicy
42 42
43 def __init__(self, filename): 43 def __init__(self, filename):
44 super().__init__(filename) 44 super().__init__(filename)
@@ -81,10 +81,6 @@ class ArchiveBasedAbstractParser(abstract.AbstractParser):
81 def remove_all(self) -> bool: 81 def remove_all(self) -> bool:
82 # pylint: disable=too-many-branches 82 # pylint: disable=too-many-branches
83 83
84 if self.unknown_member_policy not in ['omit', 'keep', 'abort']:
85 logging.error("The policy %s is invalid.", self.unknown_member_policy)
86 raise ValueError
87
88 with zipfile.ZipFile(self.filename) as zin,\ 84 with zipfile.ZipFile(self.filename) as zin,\
89 zipfile.ZipFile(self.output_filename, 'w') as zout: 85 zipfile.ZipFile(self.output_filename, 'w') as zout:
90 86
@@ -113,11 +109,11 @@ class ArchiveBasedAbstractParser(abstract.AbstractParser):
113 # supported files that we want to clean then add 109 # supported files that we want to clean then add
114 tmp_parser, mtype = parser_factory.get_parser(full_path) # type: ignore 110 tmp_parser, mtype = parser_factory.get_parser(full_path) # type: ignore
115 if not tmp_parser: 111 if not tmp_parser:
116 if self.unknown_member_policy == 'omit': 112 if self.unknown_member_policy == UnknownMemberPolicy.OMIT:
117 logging.warning("In file %s, omitting unknown element %s (format: %s)", 113 logging.warning("In file %s, omitting unknown element %s (format: %s)",
118 self.filename, item.filename, mtype) 114 self.filename, item.filename, mtype)
119 continue 115 continue
120 elif self.unknown_member_policy == 'keep': 116 elif self.unknown_member_policy == UnknownMemberPolicy.KEEP:
121 logging.warning("In file %s, keeping unknown element %s (format: %s)", 117 logging.warning("In file %s, keeping unknown element %s (format: %s)",
122 self.filename, item.filename, mtype) 118 self.filename, item.filename, mtype)
123 else: 119 else:
diff --git a/mat2 b/mat2
index 2a8ef46..0aba8d1 100755
--- a/mat2
+++ b/mat2
@@ -10,7 +10,8 @@ import multiprocessing
10import logging 10import logging
11 11
12try: 12try:
13 from libmat2 import parser_factory, UNSUPPORTED_EXTENSIONS, check_dependencies 13 from libmat2 import (parser_factory, UNSUPPORTED_EXTENSIONS, check_dependencies,
14 UnknownMemberPolicy)
14except ValueError as e: 15except ValueError as e:
15 print(e) 16 print(e)
16 sys.exit(1) 17 sys.exit(1)
@@ -42,8 +43,8 @@ def create_arg_parser():
42 parser.add_argument('-V', '--verbose', action='store_true', 43 parser.add_argument('-V', '--verbose', action='store_true',
43 help='show more verbose status information') 44 help='show more verbose status information')
44 parser.add_argument('--unknown-members', metavar='policy', default='abort', 45 parser.add_argument('--unknown-members', metavar='policy', default='abort',
45 help='how to handle unknown members of archive-style files ' + 46 help='how to handle unknown members of archive-style files (policy should' +
46 '(policy should be abort, omit, or keep)') 47 ' be one of: ' + ', '.join([x.value for x in UnknownMemberPolicy]) + ')')
47 48
48 49
49 info = parser.add_mutually_exclusive_group() 50 info = parser.add_mutually_exclusive_group()
@@ -70,7 +71,7 @@ def show_meta(filename: str):
70 except UnicodeEncodeError: 71 except UnicodeEncodeError:
71 print(" %s: harmful content" % k) 72 print(" %s: harmful content" % k)
72 73
73def clean_meta(params: Tuple[str, bool, str]) -> bool: 74def clean_meta(params: Tuple[str, bool, UnknownMemberPolicy]) -> bool:
74 filename, is_lightweight, unknown_member_policy = params 75 filename, is_lightweight, unknown_member_policy = params
75 if not __check_file(filename, os.R_OK|os.W_OK): 76 if not __check_file(filename, os.R_OK|os.W_OK):
76 return False 77 return False
@@ -137,15 +138,13 @@ def main():
137 return 0 138 return 0
138 139
139 else: 140 else:
140 if args.unknown_members == 'keep': 141 unknown_member_policy = UnknownMemberPolicy(args.unknown_members)
142 if unknown_member_policy == UnknownMemberPolicy.KEEP:
141 logging.warning('Keeping unknown member files may leak metadata in the resulting file!') 143 logging.warning('Keeping unknown member files may leak metadata in the resulting file!')
142 elif args.unknown_members not in ['omit', 'abort']:
143 logging.warning('Undefined policy for handling unknown member files: "%s"',
144 args.unknown_members)
145 p = multiprocessing.Pool() 144 p = multiprocessing.Pool()
146 mode = (args.lightweight is True) 145 mode = (args.lightweight is True)
147 l = zip(__get_files_recursively(args.files), itertools.repeat(mode), 146 l = zip(__get_files_recursively(args.files), itertools.repeat(mode),
148 itertools.repeat(args.unknown_members)) 147 itertools.repeat(unknown_member_policy))
149 148
150 ret = list(p.imap_unordered(clean_meta, list(l))) 149 ret = list(p.imap_unordered(clean_meta, list(l)))
151 return 0 if all(ret) else -1 150 return 0 if all(ret) else -1
diff --git a/tests/test_policy.py b/tests/test_policy.py
index 39282b1..5a8447b 100644
--- a/tests/test_policy.py
+++ b/tests/test_policy.py
@@ -4,28 +4,29 @@ import unittest
4import shutil 4import shutil
5import os 5import os
6 6
7from libmat2 import office 7from libmat2 import office, UnknownMemberPolicy
8 8
9class TestPolicy(unittest.TestCase): 9class TestPolicy(unittest.TestCase):
10 def test_policy_omit(self): 10 def test_policy_omit(self):
11 shutil.copy('./tests/data/embedded.docx', './tests/data/clean.docx') 11 shutil.copy('./tests/data/embedded.docx', './tests/data/clean.docx')
12 p = office.MSOfficeParser('./tests/data/clean.docx') 12 p = office.MSOfficeParser('./tests/data/clean.docx')
13 p.unknown_member_policy = 'omit' 13 p.unknown_member_policy = UnknownMemberPolicy.OMIT
14 self.assertTrue(p.remove_all()) 14 self.assertTrue(p.remove_all())
15 os.remove('./tests/data/clean.docx') 15 os.remove('./tests/data/clean.docx')
16 os.remove('./tests/data/clean.cleaned.docx')
16 17
17 def test_policy_keep(self): 18 def test_policy_keep(self):
18 shutil.copy('./tests/data/embedded.docx', './tests/data/clean.docx') 19 shutil.copy('./tests/data/embedded.docx', './tests/data/clean.docx')
19 p = office.MSOfficeParser('./tests/data/clean.docx') 20 p = office.MSOfficeParser('./tests/data/clean.docx')
20 p.unknown_member_policy = 'keep' 21 p.unknown_member_policy = UnknownMemberPolicy.KEEP
21 self.assertTrue(p.remove_all()) 22 self.assertTrue(p.remove_all())
22 os.remove('./tests/data/clean.docx') 23 os.remove('./tests/data/clean.docx')
24 os.remove('./tests/data/clean.cleaned.docx')
23 25
24 def test_policy_unknown(self): 26 def test_policy_unknown(self):
25 shutil.copy('./tests/data/embedded.docx', './tests/data/clean.docx') 27 shutil.copy('./tests/data/embedded.docx', './tests/data/clean.docx')
26 p = office.MSOfficeParser('./tests/data/clean.docx') 28 p = office.MSOfficeParser('./tests/data/clean.docx')
27 p.unknown_member_policy = 'unknown_policy_name_totally_invalid'
28 with self.assertRaises(ValueError): 29 with self.assertRaises(ValueError):
29 p.remove_all() 30 p.unknown_member_policy = UnknownMemberPolicy('unknown_policy_name_totally_invalid')
30 os.remove('./tests/data/clean.docx') 31 os.remove('./tests/data/clean.docx')
31 32