summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--main.py13
-rw-r--r--test/test_api.py8
2 files changed, 15 insertions, 6 deletions
diff --git a/main.py b/main.py
index 5c6211d..a4766b1 100644
--- a/main.py
+++ b/main.py
@@ -35,7 +35,7 @@ def create_app(test_config=None):
35 CORS(app, resources={r"/api/*": {"origins": utils.get_allow_origin_header_value()}}) 35 CORS(app, resources={r"/api/*": {"origins": utils.get_allow_origin_header_value()}})
36 36
37 @app.route('/download/<string:key>/<string:filename>') 37 @app.route('/download/<string:key>/<string:filename>')
38 def download_file(key:str, filename:str): 38 def download_file(key: str, filename:str):
39 if filename != secure_filename(filename): 39 if filename != secure_filename(filename):
40 return redirect(url_for('upload_file')) 40 return redirect(url_for('upload_file'))
41 41
@@ -173,11 +173,12 @@ def create_app(test_config=None):
173 class APIDownload(Resource): 173 class APIDownload(Resource):
174 def get(self, key: str, filename: str): 174 def get(self, key: str, filename: str):
175 complete_path, filepath = is_valid_api_download_file(filename, key) 175 complete_path, filepath = is_valid_api_download_file(filename, key)
176 176 # Make sure the file is NOT deleted on HEAD requests
177 @after_this_request 177 if request.method == 'GET':
178 def remove_file(response): 178 @after_this_request
179 os.remove(complete_path) 179 def remove_file(response):
180 return response 180 os.remove(complete_path)
181 return response
181 182
182 return send_from_directory(app.config['UPLOAD_FOLDER'], filepath) 183 return send_from_directory(app.config['UPLOAD_FOLDER'], filepath)
183 184
diff --git a/test/test_api.py b/test/test_api.py
index 532ceb9..2029820 100644
--- a/test/test_api.py
+++ b/test/test_api.py
@@ -151,6 +151,10 @@ class Mat2APITestCase(unittest.TestCase):
151 error = json.loads(request.data.decode('utf-8'))['message'] 151 error = json.loads(request.data.decode('utf-8'))['message']
152 self.assertEqual(error, 'The file hash does not match') 152 self.assertEqual(error, 'The file hash does not match')
153 153
154 request = self.app.head(data['download_link'])
155 self.assertEqual(request.status_code, 200)
156 self.assertEqual(request.headers['Content-Length'], '633')
157
154 request = self.app.get(data['download_link']) 158 request = self.app.get(data['download_link'])
155 self.assertEqual(request.status_code, 200) 159 self.assertEqual(request.status_code, 200)
156 160
@@ -210,6 +214,10 @@ class Mat2APITestCase(unittest.TestCase):
210 self.assertIn(response['mime'], 'application/zip') 214 self.assertIn(response['mime'], 'application/zip')
211 self.assertEqual(response['meta_after'], {}) 215 self.assertEqual(response['meta_after'], {})
212 216
217 request = self.app.head(response['download_link'])
218 self.assertEqual(request.status_code, 200)
219 self.assertEqual(request.headers['Content-Length'], '1596')
220
213 request = self.app.get(response['download_link']) 221 request = self.app.get(response['download_link'])
214 zip_response = zipfile.ZipFile(BytesIO(request.data)) 222 zip_response = zipfile.ZipFile(BytesIO(request.data))
215 self.assertEquals(2, len(zip_response.namelist())) 223 self.assertEquals(2, len(zip_response.namelist()))