diff --git a/cps/oauth.py b/cps/oauth.py
new file mode 100644
index 00000000..679e7f31
--- /dev/null
+++ b/cps/oauth.py
@@ -0,0 +1,134 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from flask import session
+from flask_dance.consumer.backend.sqla import SQLAlchemyBackend, first, _get_real_user
+from sqlalchemy.orm.exc import NoResultFound
+
+
+class OAuthBackend(SQLAlchemyBackend):
+ """
+ Stores and retrieves OAuth tokens using a relational database through
+ the `SQLAlchemy`_ ORM.
+
+ .. _SQLAlchemy: http://www.sqlalchemy.org/
+ """
+ def __init__(self, model, session,
+ user=None, user_id=None, user_required=None, anon_user=None,
+ cache=None):
+ super(OAuthBackend, self).__init__(model, session, user, user_id, user_required, anon_user, cache)
+
+ def get(self, blueprint, user=None, user_id=None):
+ if blueprint.name + '_oauth_token' in session and session[blueprint.name + '_oauth_token'] != '':
+ return session[blueprint.name + '_oauth_token']
+ # check cache
+ cache_key = self.make_cache_key(blueprint=blueprint, user=user, user_id=user_id)
+ token = self.cache.get(cache_key)
+ if token:
+ return token
+
+ # if not cached, make database queries
+ query = (
+ self.session.query(self.model)
+ .filter_by(provider=blueprint.name)
+ )
+ uid = first([user_id, self.user_id, blueprint.config.get("user_id")])
+ u = first(_get_real_user(ref, self.anon_user)
+ for ref in (user, self.user, blueprint.config.get("user")))
+
+ use_provider_user_id = False
+ if blueprint.name + '_oauth_user_id' in session and session[blueprint.name + '_oauth_user_id'] != '':
+ query = query.filter_by(provider_user_id=session[blueprint.name + '_oauth_user_id'])
+ use_provider_user_id = True
+
+ if self.user_required and not u and not uid and not use_provider_user_id:
+ #raise ValueError("Cannot get OAuth token without an associated user")
+ return None
+ # check for user ID
+ if hasattr(self.model, "user_id") and uid:
+ query = query.filter_by(user_id=uid)
+ # check for user (relationship property)
+ elif hasattr(self.model, "user") and u:
+ query = query.filter_by(user=u)
+ # if we have the property, but not value, filter by None
+ elif hasattr(self.model, "user_id"):
+ query = query.filter_by(user_id=None)
+ # run query
+ try:
+ token = query.one().token
+ except NoResultFound:
+ token = None
+
+ # cache the result
+ self.cache.set(cache_key, token)
+
+ return token
+
+ def set(self, blueprint, token, user=None, user_id=None):
+ uid = first([user_id, self.user_id, blueprint.config.get("user_id")])
+ u = first(_get_real_user(ref, self.anon_user)
+ for ref in (user, self.user, blueprint.config.get("user")))
+
+ if self.user_required and not u and not uid:
+ raise ValueError("Cannot set OAuth token without an associated user")
+
+ # if there was an existing model, delete it
+ existing_query = (
+ self.session.query(self.model)
+ .filter_by(provider=blueprint.name)
+ )
+ # check for user ID
+ has_user_id = hasattr(self.model, "user_id")
+ if has_user_id and uid:
+ existing_query = existing_query.filter_by(user_id=uid)
+ # check for user (relationship property)
+ has_user = hasattr(self.model, "user")
+ if has_user and u:
+ existing_query = existing_query.filter_by(user=u)
+ # queue up delete query -- won't be run until commit()
+ existing_query.delete()
+ # create a new model for this token
+ kwargs = {
+ "provider": blueprint.name,
+ "token": token,
+ }
+ if has_user_id and uid:
+ kwargs["user_id"] = uid
+ if has_user and u:
+ kwargs["user"] = u
+ self.session.add(self.model(**kwargs))
+ # commit to delete and add simultaneously
+ self.session.commit()
+ # invalidate cache
+ self.cache.delete(self.make_cache_key(
+ blueprint=blueprint, user=user, user_id=user_id
+ ))
+
+ def delete(self, blueprint, user=None, user_id=None):
+ query = (
+ self.session.query(self.model)
+ .filter_by(provider=blueprint.name)
+ )
+ uid = first([user_id, self.user_id, blueprint.config.get("user_id")])
+ u = first(_get_real_user(ref, self.anon_user)
+ for ref in (user, self.user, blueprint.config.get("user")))
+
+ if self.user_required and not u and not uid:
+ raise ValueError("Cannot delete OAuth token without an associated user")
+
+ # check for user ID
+ if hasattr(self.model, "user_id") and uid:
+ query = query.filter_by(user_id=uid)
+ # check for user (relationship property)
+ elif hasattr(self.model, "user") and u:
+ query = query.filter_by(user=u)
+ # if we have the property, but not value, filter by None
+ elif hasattr(self.model, "user_id"):
+ query = query.filter_by(user_id=None)
+ # run query
+ query.delete()
+ self.session.commit()
+ # invalidate cache
+ self.cache.delete(self.make_cache_key(
+ blueprint=blueprint, user=user, user_id=user_id,
+ ))
diff --git a/cps/templates/config_edit.html b/cps/templates/config_edit.html
index 8c0e8281..2d77f038 100644
--- a/cps/templates/config_edit.html
+++ b/cps/templates/config_edit.html
@@ -185,6 +185,34 @@
+
+
+ {{_('Obtain GitHub OAuth Credentail')}}
+
+
diff --git a/cps/ub.py b/cps/ub.py
index dc0ffa75..541b12bf 100644
--- a/cps/ub.py
+++ b/cps/ub.py
@@ -23,6 +23,7 @@ from sqlalchemy import exc
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import *
from flask_login import AnonymousUserMixin
+from flask_dance.consumer.backend.sqla import OAuthConsumerMixin
import sys
import os
import logging
@@ -197,6 +198,12 @@ class User(UserBase, Base):
mature_content = Column(Boolean, default=True)
+class OAuth(OAuthConsumerMixin, Base):
+ provider_user_id = Column(String(256))
+ user_id = Column(Integer, ForeignKey(User.id))
+ user = relationship(User)
+
+
# Class for anonymous user is derived from User base and completly overrides methods and properties for the
# anonymous user
class Anonymous(AnonymousUserMixin, UserBase):
@@ -337,6 +344,12 @@ class Settings(Base):
config_use_ldap = Column(Boolean)
config_ldap_provider_url = Column(String)
config_ldap_dn = Column(String)
+ config_use_github_oauth = Column(Boolean)
+ config_github_oauth_client_id = Column(String)
+ config_github_oauth_client_secret = Column(String)
+ config_use_google_oauth = Column(Boolean)
+ config_google_oauth_client_id = Column(String)
+ config_google_oauth_client_secret = Column(String)
config_mature_content_tags = Column(String)
config_logfile = Column(String)
config_ebookconverter = Column(Integer, default=0)
@@ -414,6 +427,12 @@ class Config:
self.config_use_ldap = data.config_use_ldap
self.config_ldap_provider_url = data.config_ldap_provider_url
self.config_ldap_dn = data.config_ldap_dn
+ self.config_use_github_oauth = data.config_use_github_oauth
+ self.config_github_oauth_client_id = data.config_github_oauth_client_id
+ self.config_github_oauth_client_secret = data.config_github_oauth_client_secret
+ self.config_use_google_oauth = data.config_use_google_oauth
+ self.config_google_oauth_client_id = data.config_google_oauth_client_id
+ self.config_google_oauth_client_secret = data.config_google_oauth_client_secret
if data.config_mature_content_tags:
self.config_mature_content_tags = data.config_mature_content_tags
else:
@@ -722,6 +741,22 @@ def migrate_Database():
conn.execute("ALTER TABLE Settings ADD column `config_updatechannel` INTEGER DEFAULT 0")
session.commit()
+ try:
+ session.query(exists().where(Settings.config_use_github_oauth)).scalar()
+ except exc.OperationalError:
+ conn = engine.connect()
+ conn.execute("ALTER TABLE Settings ADD column `config_use_github_oauth` INTEGER DEFAULT 0")
+ conn.execute("ALTER TABLE Settings ADD column `config_github_oauth_client_id` String DEFAULT ''")
+ conn.execute("ALTER TABLE Settings ADD column `config_github_oauth_client_secret` String DEFAULT ''")
+ session.commit()
+ try:
+ session.query(exists().where(Settings.config_use_google_oauth)).scalar()
+ except exc.OperationalError:
+ conn = engine.connect()
+ conn.execute("ALTER TABLE Settings ADD column `config_use_google_oauth` INTEGER DEFAULT 0")
+ conn.execute("ALTER TABLE Settings ADD column `config_google_oauth_client_id` String DEFAULT ''")
+ conn.execute("ALTER TABLE Settings ADD column `config_google_oauth_client_secret` String DEFAULT ''")
+ session.commit()
# Remove login capability of user Guest
conn = engine.connect()
conn.execute("UPDATE user SET password='' where nickname = 'Guest' and password !=''")
diff --git a/cps/web.py b/cps/web.py
index 3eb9ff04..25e44dfe 100644
--- a/cps/web.py
+++ b/cps/web.py
@@ -24,7 +24,7 @@
import mimetypes
import logging
from logging.handlers import RotatingFileHandler
-from flask import (Flask, render_template, request, Response, redirect,
+from flask import (Flask, session, render_template, request, Response, redirect,
url_for, send_from_directory, make_response, g, flash,
abort, Markup)
from flask import __version__ as flaskVersion
@@ -78,6 +78,11 @@ import time
import server
from reverseproxy import ReverseProxied
from updater import updater_thread
+from flask_dance.contrib.github import make_github_blueprint, github
+from flask_dance.contrib.google import make_google_blueprint, google
+from flask_dance.consumer import oauth_authorized, oauth_error
+from sqlalchemy.orm.exc import NoResultFound
+from oauth import OAuthBackend
try:
from googleapiclient.errors import HttpError
@@ -142,6 +147,7 @@ EXTENSIONS_AUDIO = {'mp3', 'm4a', 'm4b'}
# EXTENSIONS_READER = set(['txt', 'pdf', 'epub', 'zip', 'cbz', 'tar', 'cbt'] + (['rar','cbr'] if rar_support else []))
+oauth_check = {}
@@ -369,6 +375,35 @@ def remote_login_required(f):
return inner
+def github_oauth_required(f):
+ @wraps(f)
+ def inner(*args, **kwargs):
+ if config.config_use_github_oauth:
+ return f(*args, **kwargs)
+ if request.is_xhr:
+ data = {'status': 'error', 'message': 'Not Found'}
+ response = make_response(json.dumps(data, ensure_ascii=False))
+ response.headers["Content-Type"] = "application/json; charset=utf-8"
+ return response, 404
+ abort(404)
+
+ return inner
+
+
+def google_oauth_required(f):
+ @wraps(f)
+ def inner(*args, **kwargs):
+ if config.config_use_google_oauth:
+ return f(*args, **kwargs)
+ if request.is_xhr:
+ data = {'status': 'error', 'message': 'Not Found'}
+ response = make_response(json.dumps(data, ensure_ascii=False))
+ response.headers["Content-Type"] = "application/json; charset=utf-8"
+ return response, 404
+ abort(404)
+
+ return inner
+
# custom jinja filters
# pagination links in jinja
@@ -2261,6 +2296,7 @@ def register():
try:
ub.session.add(content)
ub.session.commit()
+ register_user_with_oauth(content)
helper.send_registration_mail(to_save["email"],to_save["nickname"], password)
except Exception:
ub.session.rollback()
@@ -2276,7 +2312,8 @@ def register():
flash(_(u"This username or e-mail address is already in use."), category="error")
return render_title_template('register.html', title=_(u"register"), page="register")
- return render_title_template('register.html', title=_(u"register"), page="register")
+ register_user_with_oauth()
+ return render_title_template('register.html', config=config, title=_(u"register"), page="register")
@app.route('/login', methods=['GET', 'POST'])
@@ -2312,8 +2349,7 @@ def login():
# if next_url is None or not is_safe_url(next_url):
next_url = url_for('index')
- return render_title_template('login.html', title=_(u"login"), next_url=next_url,
- remote_login=config.config_remote_login, page="login")
+ return render_title_template('login.html', title=_(u"login"), next_url=next_url, config=config, page="login")
@app.route('/logout')
@@ -2321,6 +2357,7 @@ def login():
def logout():
if current_user is not None and current_user.is_authenticated:
logout_user()
+ logout_oauth_user()
return redirect(url_for('login'))
@@ -2752,6 +2789,7 @@ def profile():
downloads = list()
languages = speaking_language()
translations = babel.list_translations() + [LC('en')]
+ oauth_status = get_oauth_status()
for book in content.downloads:
downloadBook = db.session.query(db.Books).filter(db.Books.id == book.book_id).first()
if downloadBook:
@@ -2814,11 +2852,11 @@ def profile():
ub.session.rollback()
flash(_(u"Found an existing account for this e-mail address."), category="error")
return render_title_template("user_edit.html", content=content, downloads=downloads,
- title=_(u"%(name)s's profile", name=current_user.nickname))
+ title=_(u"%(name)s's profile", name=current_user.nickname, registered_oauth=oauth_check, oauth_status=oauth_status))
flash(_(u"Profile updated"), category="success")
return render_title_template("user_edit.html", translations=translations, profile=1, languages=languages,
content=content, downloads=downloads, title=_(u"%(name)s's profile",
- name=current_user.nickname), page="me")
+ name=current_user.nickname), page="me", registered_oauth=oauth_check, oauth_status=oauth_status)
@app.route("/admin/view")
@@ -3083,6 +3121,29 @@ def configuration_helper(origin):
content.config_goodreads_api_secret = to_save["config_goodreads_api_secret"]
if "config_updater" in to_save:
content.config_updatechannel = int(to_save["config_updater"])
+
+ # GitHub OAuth configuration
+ content.config_use_github_oauth = ("config_use_github_oauth" in to_save and to_save["config_use_github_oauth"] == "on")
+ if "config_github_oauth_client_id" in to_save:
+ content.config_github_oauth_client_id = to_save["config_github_oauth_client_id"]
+ if "config_github_oauth_client_secret" in to_save:
+ content.config_github_oauth_client_secret = to_save["config_github_oauth_client_secret"]
+
+ if content.config_github_oauth_client_id != config.config_github_oauth_client_id or \
+ content.config_github_oauth_client_secret != config.config_github_oauth_client_secret:
+ reboot_required = True
+
+ # Google OAuth configuration
+ content.config_use_google_oauth = ("config_use_google_oauth" in to_save and to_save["config_use_google_oauth"] == "on")
+ if "config_google_oauth_client_id" in to_save:
+ content.config_google_oauth_client_id = to_save["config_google_oauth_client_id"]
+ if "config_google_oauth_client_secret" in to_save:
+ content.config_google_oauth_client_secret = to_save["config_google_oauth_client_secret"]
+
+ if content.config_google_oauth_client_id != config.config_google_oauth_client_id or \
+ content.config_google_oauth_client_secret != config.config_google_oauth_client_secret:
+ reboot_required = True
+
if "config_log_level" in to_save:
content.config_log_level = int(to_save["config_log_level"])
if content.config_logfile != to_save["config_logfile"]:
@@ -3977,3 +4038,267 @@ def convert_bookformat(book_id):
else:
flash(_(u"There was an error converting this book: %(res)s", res=rtn), category="error")
return redirect(request.environ["HTTP_REFERER"])
+
+
+def register_oauth_blueprint(blueprint, show_name):
+ if blueprint.name != "":
+ oauth_check[blueprint.name] = show_name
+
+
+def register_user_with_oauth(user=None):
+ all_oauth = {}
+ for oauth in oauth_check.keys():
+ if oauth + '_oauth_user_id' in session and session[oauth + '_oauth_user_id'] != '':
+ all_oauth[oauth] = oauth_check[oauth]
+ if len(all_oauth.keys()) == 0:
+ return
+ if user is None:
+ flash(_(u"Register with %s" % ", ".join(list(all_oauth.values()))), category="success")
+ else:
+ for oauth in all_oauth.keys():
+ # Find this OAuth token in the database, or create it
+ query = ub.session.query(ub.OAuth).filter_by(
+ provider=oauth,
+ provider_user_id=session[oauth + "_oauth_user_id"],
+ )
+ try:
+ oauth = query.one()
+ oauth.user_id = user.id
+ except NoResultFound:
+ # no found, return error
+ return
+ try:
+ ub.session.commit()
+ except Exception as e:
+ app.logger.exception(e)
+ ub.session.rollback()
+
+
+def logout_oauth_user():
+ for oauth in oauth_check.keys():
+ if oauth + '_oauth_user_id' in session:
+ session.pop(oauth + '_oauth_user_id')
+
+
+github_blueprint = make_github_blueprint(
+ client_id=config.config_github_oauth_client_id,
+ client_secret=config.config_github_oauth_client_secret,
+ redirect_to="github_login",)
+
+google_blueprint = make_google_blueprint(
+ client_id=config.config_google_oauth_client_id,
+ client_secret=config.config_google_oauth_client_secret,
+ redirect_to="google_login",
+ scope=[
+ "https://www.googleapis.com/auth/plus.me",
+ "https://www.googleapis.com/auth/userinfo.email",
+ ]
+)
+
+app.register_blueprint(google_blueprint, url_prefix="/login")
+app.register_blueprint(github_blueprint, url_prefix='/login')
+
+github_blueprint.backend = OAuthBackend(ub.OAuth, ub.session, user=current_user, user_required=True)
+google_blueprint.backend = OAuthBackend(ub.OAuth, ub.session, user=current_user, user_required=True)
+
+
+if config.config_use_github_oauth:
+ register_oauth_blueprint(github_blueprint, 'GitHub')
+if config.config_use_google_oauth:
+ register_oauth_blueprint(google_blueprint, 'Google')
+
+
+@oauth_authorized.connect_via(github_blueprint)
+def github_logged_in(blueprint, token):
+ if not token:
+ flash(_("Failed to log in with GitHub."), category="error")
+ return False
+
+ resp = blueprint.session.get("/user")
+ if not resp.ok:
+ flash(_("Failed to fetch user info from GitHub."), category="error")
+ return False
+
+ github_info = resp.json()
+ github_user_id = str(github_info["id"])
+ return oauth_update_token(blueprint, token, github_user_id)
+
+
+@oauth_authorized.connect_via(google_blueprint)
+def google_logged_in(blueprint, token):
+ if not token:
+ flash(_("Failed to log in with Google."), category="error")
+ return False
+
+ resp = blueprint.session.get("/oauth2/v2/userinfo")
+ if not resp.ok:
+ flash(_("Failed to fetch user info from Google."), category="error")
+ return False
+
+ google_info = resp.json()
+ google_user_id = str(google_info["id"])
+
+ return oauth_update_token(blueprint, token, google_user_id)
+
+
+def oauth_update_token(blueprint, token, provider_user_id):
+ session[blueprint.name + "_oauth_user_id"] = provider_user_id
+ session[blueprint.name + "_oauth_token"] = token
+
+ # Find this OAuth token in the database, or create it
+ query = ub.session.query(ub.OAuth).filter_by(
+ provider=blueprint.name,
+ provider_user_id=provider_user_id,
+ )
+ try:
+ oauth = query.one()
+ # update token
+ oauth.token = token
+ except NoResultFound:
+ oauth = ub.OAuth(
+ provider=blueprint.name,
+ provider_user_id=provider_user_id,
+ token=token,
+ )
+ try:
+ ub.session.add(oauth)
+ ub.session.commit()
+ except Exception as e:
+ app.logger.exception(e)
+ ub.session.rollback()
+
+ # Disable Flask-Dance's default behavior for saving the OAuth token
+ return False
+
+
+def bind_oauth_or_register(provider, provider_user_id, redirect_url):
+ query = ub.session.query(ub.OAuth).filter_by(
+ provider=provider,
+ provider_user_id=provider_user_id,
+ )
+ try:
+ oauth = query.one()
+ # already bind with user, just login
+ if oauth.user:
+ login_user(oauth.user)
+ return redirect(url_for('index'))
+ else:
+ # bind to current user
+ if current_user and current_user.is_authenticated:
+ oauth.user = current_user
+ try:
+ ub.session.add(oauth)
+ ub.session.commit()
+ except Exception as e:
+ app.logger.exception(e)
+ ub.session.rollback()
+ return redirect(url_for('register'))
+ except NoResultFound:
+ return redirect(url_for(redirect_url))
+
+
+def get_oauth_status():
+ status = []
+ query = ub.session.query(ub.OAuth).filter_by(
+ user_id=current_user.id,
+ )
+ try:
+ oauths = query.all()
+ for oauth in oauths:
+ status.append(oauth.provider)
+ return status
+ except NoResultFound:
+ return None
+
+
+def unlink_oauth(provider):
+ if request.host_url + 'me' != request.referrer:
+ pass
+ query = ub.session.query(ub.OAuth).filter_by(
+ provider=provider,
+ user_id=current_user.id,
+ )
+ try:
+ oauth = query.one()
+ if current_user and current_user.is_authenticated:
+ oauth.user = current_user
+ try:
+ ub.session.delete(oauth)
+ ub.session.commit()
+ logout_oauth_user()
+ flash(_("Unlink to %(oauth)s success.", oauth=oauth_check[provider]), category="success")
+ except Exception as e:
+ app.logger.exception(e)
+ ub.session.rollback()
+ flash(_("Unlink to %(oauth)s failed.", oauth=oauth_check[provider]), category="error")
+ except NoResultFound:
+ app.logger.warning("oauth %s for user %d not fount" % (provider, current_user.id))
+ flash(_("Not linked to %(oauth)s.", oauth=oauth_check[provider]), category="error")
+ return redirect(url_for('profile'))
+
+
+# notify on OAuth provider error
+@oauth_error.connect_via(github_blueprint)
+def github_error(blueprint, error, error_description=None, error_uri=None):
+ msg = (
+ "OAuth error from {name}! "
+ "error={error} description={description} uri={uri}"
+ ).format(
+ name=blueprint.name,
+ error=error,
+ description=error_description,
+ uri=error_uri,
+ )
+ flash(msg, category="error")
+
+
+@app.route('/github')
+@github_oauth_required
+def github_login():
+ if not github.authorized:
+ return redirect(url_for('github.login'))
+ account_info = github.get('/user')
+ if account_info.ok:
+ account_info_json = account_info.json()
+ return bind_oauth_or_register(github_blueprint.name, account_info_json['id'], 'github.login')
+ flash(_(u"GitHub Oauth error, please retry later."), category="error")
+ return redirect(url_for('login'))
+
+
+@app.route('/unlink/github', methods=["GET"])
+@login_required
+def github_login_unlink():
+ return unlink_oauth(github_blueprint.name)
+
+
+@app.route('/google')
+@google_oauth_required
+def google_login():
+ if not google.authorized:
+ return redirect(url_for("google.login"))
+ resp = google.get("/oauth2/v2/userinfo")
+ if resp.ok:
+ account_info_json = resp.json()
+ return bind_oauth_or_register(google_blueprint.name, account_info_json['id'], 'google.login')
+ flash(_(u"Google Oauth error, please retry later."), category="error")
+ return redirect(url_for('login'))
+
+
+@oauth_error.connect_via(google_blueprint)
+def google_error(blueprint, error, error_description=None, error_uri=None):
+ msg = (
+ "OAuth error from {name}! "
+ "error={error} description={description} uri={uri}"
+ ).format(
+ name=blueprint.name,
+ error=error,
+ description=error_description,
+ uri=error_uri,
+ )
+ flash(msg, category="error")
+
+
+@app.route('/unlink/google', methods=["GET"])
+@login_required
+def google_login_unlink():
+ return unlink_oauth(google_blueprint.name)
diff --git a/requirements.txt b/requirements.txt
index 3fb23ea3..2b13eb54 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,3 +13,5 @@ SQLAlchemy>=1.1.0
tornado>=4.1
Wand>=0.4.4
unidecode>=0.04.19
+flask-dance>=0.13.0
+sqlalchemy_utils>=0.33.5