Refactor app to use enum constants for endpoints

Simplifying routing a bit helps me to check which endpoint the user is
visiting. Also will make it a lot easier down the road to rewrite,
whenever I get around to renaming "autocomplete" to a more appropriate
name (like "suggestion" or something).
This commit is contained in:
Ben Busby 2021-11-16 18:46:37 -07:00
parent b0733fd74a
commit ba7409f230
No known key found for this signature in database
GPG Key ID: 339B7B7EB5333D14
8 changed files with 83 additions and 45 deletions

View File

@ -1,3 +1,4 @@
from app.models.endpoint import Endpoint
from app.request import VALID_PARAMS, MAPS_URL from app.request import VALID_PARAMS, MAPS_URL
from app.utils.misc import read_config_bool from app.utils.misc import read_config_bool
from app.utils.results import * from app.utils.results import *
@ -250,7 +251,7 @@ class Filter:
element['src'] = BLANK_B64 element['src'] = BLANK_B64
return return
element['src'] = 'element?url=' + self.encrypt_path( element['src'] = f'{Endpoint.element}?url=' + self.encrypt_path(
src, src,
is_element=True) + '&type=' + urlparse.quote(mime) is_element=True) + '&type=' + urlparse.quote(mime)
@ -385,7 +386,8 @@ class Filter:
if len(urls) != 2: if len(urls) != 2:
continue continue
img_url = urlparse.unquote(urls[0].replace('/imgres?imgurl=', '')) img_url = urlparse.unquote(urls[0].replace(
f'/{Endpoint.imgres}?imgurl=', ''))
try: try:
# Try to strip out only the necessary part of the web page link # Try to strip out only the necessary part of the web page link

22
app/models/endpoint.py Normal file
View File

@ -0,0 +1,22 @@
from enum import Enum
class Endpoint(Enum):
autocomplete = 'autocomplete'
home = 'home'
healthz = 'healthz'
session = 'session'
config = 'config'
opensearch = 'opensearch.xml'
search = 'search'
search_html = 'search.html'
url = 'url'
imgres = 'imgres'
element = 'element'
window = 'window'
def __repr__(self):
return self.value
def in_path(self, path: str) -> bool:
return path.startswith(self.value)

View File

@ -1,6 +1,5 @@
import argparse import argparse
import base64 import base64
import html
import io import io
import json import json
import pickle import pickle
@ -12,6 +11,7 @@ from functools import wraps
import waitress import waitress
from app import app from app import app
from app.models.config import Config from app.models.config import Config
from app.models.endpoint import Endpoint
from app.request import Request, TorError from app.request import Request, TorError
from app.utils.bangs import resolve_bang from app.utils.bangs import resolve_bang
from app.utils.misc import read_config_bool, get_client_ip from app.utils.misc import read_config_bool, get_client_ip
@ -20,6 +20,7 @@ from app.utils.results import bold_search_terms
from app.utils.search import * from app.utils.search import *
from app.utils.session import generate_user_key, valid_user_session from app.utils.session import generate_user_key, valid_user_session
from bs4 import BeautifulSoup as bsoup from bs4 import BeautifulSoup as bsoup
from enum import Enum
from flask import jsonify, make_response, request, redirect, render_template, \ from flask import jsonify, make_response, request, redirect, render_template, \
send_file, session, url_for send_file, session, url_for
from requests import exceptions, get from requests import exceptions, get
@ -110,10 +111,17 @@ def before_request_func():
session['config'] = default_config session['config'] = default_config
session['uuid'] = str(uuid.uuid4()) session['uuid'] = str(uuid.uuid4())
session['key'] = generate_user_key() session['key'] = generate_user_key()
return redirect(url_for(
'session_check', # Skip checking for session on /autocomplete searches,
session_id=session['uuid'], # since they can be done from the browser search bar (aka
follow=request.url), code=307) # no ability to initialize a session)
if not Endpoint.autocomplete.in_path(request.path):
return redirect(url_for(
'session_check',
session_id=session['uuid'],
follow=request.url), code=307)
else:
g.user_config = Config(**session['config'])
elif 'cookies_disabled' not in request.args: elif 'cookies_disabled' not in request.args:
# Set session as permanent # Set session as permanent
session.permanent = True session.permanent = True
@ -158,17 +166,17 @@ def unknown_page(e):
return redirect(g.app_location) return redirect(g.app_location)
@app.route('/healthz', methods=['GET']) @app.route(f'/{Endpoint.healthz}', methods=['GET'])
def healthz(): def healthz():
return '' return ''
@app.route('/home', methods=['GET']) @app.route(f'/{Endpoint.home}', methods=['GET'])
def home(): def home():
return redirect(url_for('.index')) return redirect(url_for('.index'))
@app.route('/session/<session_id>', methods=['GET', 'PUT', 'POST']) @app.route(f'{Endpoint.session}/<session_id>', methods=['GET', 'PUT', 'POST'])
def session_check(session_id): def session_check(session_id):
if 'uuid' in session and session['uuid'] == session_id: if 'uuid' in session and session['uuid'] == session_id:
session['valid'] = True session['valid'] = True
@ -210,7 +218,7 @@ def index():
version_number=app.config['VERSION_NUMBER']) version_number=app.config['VERSION_NUMBER'])
@app.route('/opensearch.xml', methods=['GET']) @app.route(f'/{Endpoint.opensearch}', methods=['GET'])
def opensearch(): def opensearch():
opensearch_url = g.app_location opensearch_url = g.app_location
if opensearch_url.endswith('/'): if opensearch_url.endswith('/'):
@ -230,7 +238,7 @@ def opensearch():
), 200, {'Content-Disposition': 'attachment; filename="opensearch.xml"'} ), 200, {'Content-Disposition': 'attachment; filename="opensearch.xml"'}
@app.route('/search.html', methods=['GET']) @app.route(f'/{Endpoint.search_html}', methods=['GET'])
def search_html(): def search_html():
search_url = g.app_location search_url = g.app_location
if search_url.endswith('/'): if search_url.endswith('/'):
@ -238,8 +246,7 @@ def search_html():
return render_template('search.html', url=search_url) return render_template('search.html', url=search_url)
@app.route('/autocomplete', methods=['GET', 'POST']) @app.route(f'/{Endpoint.autocomplete}', methods=['GET', 'POST'])
@session_required
def autocomplete(): def autocomplete():
ac_var = 'WHOOGLE_AUTOCOMPLETE' ac_var = 'WHOOGLE_AUTOCOMPLETE'
if os.getenv(ac_var) and not read_config_bool(ac_var): if os.getenv(ac_var) and not read_config_bool(ac_var):
@ -272,7 +279,7 @@ def autocomplete():
]) ])
@app.route('/search', methods=['GET', 'POST']) @app.route(f'/{Endpoint.search}', methods=['GET', 'POST'])
@session_required @session_required
@auth_required @auth_required
def search(): def search():
@ -288,7 +295,7 @@ def search():
# Redirect to home if invalid/blank search # Redirect to home if invalid/blank search
if not query: if not query:
return redirect('/') return redirect(url_for('.index'))
# Generate response and number of external elements from the page # Generate response and number of external elements from the page
try: try:
@ -348,7 +355,7 @@ def search():
search_util.search_type else '')), resp_code search_util.search_type else '')), resp_code
@app.route('/config', methods=['GET', 'POST', 'PUT']) @app.route(f'/{Endpoint.config}', methods=['GET', 'POST', 'PUT'])
@session_required @session_required
@auth_required @auth_required
def config(): def config():
@ -387,7 +394,7 @@ def config():
return redirect(url_for('.index'), code=403) return redirect(url_for('.index'), code=403)
@app.route('/url', methods=['GET']) @app.route(f'/{Endpoint.url}', methods=['GET'])
@session_required @session_required
@auth_required @auth_required
def url(): def url():
@ -403,14 +410,14 @@ def url():
error_message='Unable to resolve query: ' + q) error_message='Unable to resolve query: ' + q)
@app.route('/imgres') @app.route(f'/{Endpoint.imgres}')
@session_required @session_required
@auth_required @auth_required
def imgres(): def imgres():
return redirect(request.args.get('imgurl')) return redirect(request.args.get('imgurl'))
@app.route('/element') @app.route(f'/{Endpoint.element}')
@session_required @session_required
@auth_required @auth_required
def element(): def element():
@ -433,7 +440,7 @@ def element():
return send_file(io.BytesIO(empty_gif), mimetype='image/gif') return send_file(io.BytesIO(empty_gif), mimetype='image/gif')
@app.route('/window') @app.route(f'/{Endpoint.window}')
@auth_required @auth_required
def window(): def window():
get_body = g.user_request.send(base_url=request.args.get('location')).text get_body = g.user_request.send(base_url=request.args.get('location')).text

View File

@ -1,3 +1,4 @@
from app.models.endpoint import Endpoint
from bs4 import BeautifulSoup, NavigableString from bs4 import BeautifulSoup, NavigableString
import html import html
import os import os
@ -177,7 +178,7 @@ def append_nojs(result: BeautifulSoup) -> None:
""" """
nojs_link = BeautifulSoup(features='html.parser').new_tag('a') nojs_link = BeautifulSoup(features='html.parser').new_tag('a')
nojs_link['href'] = '/window?location=' + result['href'] nojs_link['href'] = f'/{Endpoint.window}?location=' + result['href']
nojs_link.string = ' NoJS Link' nojs_link.string = ' NoJS Link'
result.append(nojs_link) result.append(nojs_link)

View File

@ -1,12 +1,15 @@
from app.models.endpoint import Endpoint
def test_autocomplete_get(client): def test_autocomplete_get(client):
rv = client.get('/autocomplete?q=green+eggs+and') rv = client.get(f'/{Endpoint.autocomplete}?q=green+eggs+and')
assert rv._status_code == 200 assert rv._status_code == 200
assert len(rv.data) >= 1 assert len(rv.data) >= 1
assert b'green eggs and ham' in rv.data assert b'green eggs and ham' in rv.data
def test_autocomplete_post(client): def test_autocomplete_post(client):
rv = client.post('/autocomplete', data=dict(q='the+cat+in+the')) rv = client.post(f'/{Endpoint.autocomplete}', data=dict(q='the+cat+in+the'))
assert rv._status_code == 200 assert rv._status_code == 200
assert len(rv.data) >= 1 assert len(rv.data) >= 1
assert b'the cat in the hat' in rv.data assert b'the cat in the hat' in rv.data

View File

@ -1,6 +1,7 @@
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from app import app from app import app
from app.models.endpoint import Endpoint
from app.utils.session import generate_user_key, valid_user_session from app.utils.session import generate_user_key, valid_user_session
@ -37,13 +38,13 @@ def test_query_decryption(client):
rv = client.get('/') rv = client.get('/')
cookie = rv.headers['Set-Cookie'] cookie = rv.headers['Set-Cookie']
rv = client.get('/search?q=test+1', headers={'Cookie': cookie}) rv = client.get(f'/{Endpoint.search}?q=test+1', headers={'Cookie': cookie})
assert rv._status_code == 200 assert rv._status_code == 200
with client.session_transaction() as session: with client.session_transaction() as session:
assert valid_user_session(session) assert valid_user_session(session)
rv = client.get('/search?q=test+2', headers={'Cookie': cookie}) rv = client.get(f'/{Endpoint.search}?q=test+2', headers={'Cookie': cookie})
assert rv._status_code == 200 assert rv._status_code == 200
with client.session_transaction() as session: with client.session_transaction() as session:

View File

@ -1,5 +1,6 @@
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from app.filter import Filter from app.filter import Filter
from app.models.endpoint import Endpoint
from app.utils.session import generate_user_key from app.utils.session import generate_user_key
from datetime import datetime from datetime import datetime
from dateutil.parser import * from dateutil.parser import *
@ -30,7 +31,7 @@ def get_search_results(data):
def test_get_results(client): def test_get_results(client):
rv = client.get('/search?q=test') rv = client.get(f'/{Endpoint.search}?q=test')
assert rv._status_code == 200 assert rv._status_code == 200
# Depending on the search, there can be more # Depending on the search, there can be more
@ -41,7 +42,7 @@ def test_get_results(client):
def test_post_results(client): def test_post_results(client):
rv = client.post('/search', data=dict(q='test')) rv = client.post(f'/{Endpoint.search}', data=dict(q='test'))
assert rv._status_code == 200 assert rv._status_code == 200
# Depending on the search, there can be more # Depending on the search, there can be more
@ -52,7 +53,7 @@ def test_post_results(client):
def test_translate_search(client): def test_translate_search(client):
rv = client.post('/search', data=dict(q='translate hola')) rv = client.post(f'/{Endpoint.search}', data=dict(q='translate hola'))
assert rv._status_code == 200 assert rv._status_code == 200
# Pretty weak test, but better than nothing # Pretty weak test, but better than nothing
@ -62,7 +63,7 @@ def test_translate_search(client):
def test_block_results(client): def test_block_results(client):
rv = client.post('/search', data=dict(q='pinterest')) rv = client.post(f'/{Endpoint.search}', data=dict(q='pinterest'))
assert rv._status_code == 200 assert rv._status_code == 200
has_pinterest = False has_pinterest = False
@ -74,10 +75,10 @@ def test_block_results(client):
assert has_pinterest assert has_pinterest
demo_config['block'] = 'pinterest.com' demo_config['block'] = 'pinterest.com'
rv = client.post('/config', data=demo_config) rv = client.post(f'/{Endpoint.config}', data=demo_config)
assert rv._status_code == 302 assert rv._status_code == 302
rv = client.post('/search', data=dict(q='pinterest')) rv = client.post(f'/{Endpoint.search}', data=dict(q='pinterest'))
assert rv._status_code == 200 assert rv._status_code == 200
for link in BeautifulSoup(rv.data, 'html.parser').find_all('a', href=True): for link in BeautifulSoup(rv.data, 'html.parser').find_all('a', href=True):
@ -106,7 +107,7 @@ def test_recent_results(client):
} }
for time, num_days in times.items(): for time, num_days in times.items():
rv = client.post('/search', data=dict(q='test :' + time)) rv = client.post(f'/{Endpoint.search}', data=dict(q='test :' + time))
result_divs = get_search_results(rv.data) result_divs = get_search_results(rv.data)
current_date = datetime.now() current_date = datetime.now()

View File

@ -1,4 +1,5 @@
from app import app from app import app
from app.models.endpoint import Endpoint
import json import json
@ -11,47 +12,47 @@ def test_main(client):
def test_search(client): def test_search(client):
rv = client.get('/search?q=test') rv = client.get(f'/{Endpoint.search}?q=test')
assert rv._status_code == 200 assert rv._status_code == 200
def test_feeling_lucky(client): def test_feeling_lucky(client):
rv = client.get('/search?q=!%20test') rv = client.get(f'/{Endpoint.search}?q=!%20test')
assert rv._status_code == 303 assert rv._status_code == 303
def test_ddg_bang(client): def test_ddg_bang(client):
# Bang at beginning of query # Bang at beginning of query
rv = client.get('/search?q=!gh%20whoogle') rv = client.get(f'/{Endpoint.search}?q=!gh%20whoogle')
assert rv._status_code == 302 assert rv._status_code == 302
assert rv.headers.get('Location').startswith('https://github.com') assert rv.headers.get('Location').startswith('https://github.com')
# Move bang to end of query # Move bang to end of query
rv = client.get('/search?q=github%20!w') rv = client.get(f'/{Endpoint.search}?q=github%20!w')
assert rv._status_code == 302 assert rv._status_code == 302
assert rv.headers.get('Location').startswith('https://en.wikipedia.org') assert rv.headers.get('Location').startswith('https://en.wikipedia.org')
# Move bang to middle of query # Move bang to middle of query
rv = client.get('/search?q=big%20!r%20chungus') rv = client.get(f'/{Endpoint.search}?q=big%20!r%20chungus')
assert rv._status_code == 302 assert rv._status_code == 302
assert rv.headers.get('Location').startswith('https://www.reddit.com') assert rv.headers.get('Location').startswith('https://www.reddit.com')
# Move '!' to end of the bang # Move '!' to end of the bang
rv = client.get('/search?q=gitlab%20w!') rv = client.get(f'/{Endpoint.search}?q=gitlab%20w!')
assert rv._status_code == 302 assert rv._status_code == 302
assert rv.headers.get('Location').startswith('https://en.wikipedia.org') assert rv.headers.get('Location').startswith('https://en.wikipedia.org')
# Ensure bang is case insensitive # Ensure bang is case insensitive
rv = client.get('/search?q=!GH%20whoogle') rv = client.get(f'/{Endpoint.search}?q=!GH%20whoogle')
assert rv._status_code == 302 assert rv._status_code == 302
assert rv.headers.get('Location').startswith('https://github.com') assert rv.headers.get('Location').startswith('https://github.com')
def test_config(client): def test_config(client):
rv = client.post('/config', data=demo_config) rv = client.post(f'/{Endpoint.config}', data=demo_config)
assert rv._status_code == 302 assert rv._status_code == 302
rv = client.get('/config') rv = client.get(f'/{Endpoint.config}')
assert rv._status_code == 200 assert rv._status_code == 200
config = json.loads(rv.data) config = json.loads(rv.data)
@ -62,15 +63,15 @@ def test_config(client):
app.config['CONFIG_DISABLE'] = 1 app.config['CONFIG_DISABLE'] = 1
dark_mod = not demo_config['dark'] dark_mod = not demo_config['dark']
demo_config['dark'] = dark_mod demo_config['dark'] = dark_mod
rv = client.post('/config', data=demo_config) rv = client.post(f'/{Endpoint.config}', data=demo_config)
assert rv._status_code == 403 assert rv._status_code == 403
rv = client.get('/config') rv = client.get(f'/{Endpoint.config}')
config = json.loads(rv.data) config = json.loads(rv.data)
assert config['dark'] != dark_mod assert config['dark'] != dark_mod
def test_opensearch(client): def test_opensearch(client):
rv = client.get('/opensearch.xml') rv = client.get(f'/{Endpoint.opensearch}')
assert rv._status_code == 200 assert rv._status_code == 200
assert '<ShortName>Whoogle</ShortName>' in str(rv.data) assert '<ShortName>Whoogle</ShortName>' in str(rv.data)