diff --git a/lib/consumer/express.js b/lib/consumer/express.js index 959f649c..7f056449 100644 --- a/lib/consumer/express.js +++ b/lib/consumer/express.js @@ -41,30 +41,35 @@ module.exports = function (_config) { connect(req, res) }) - var transport = (provider, res, session) => (data) => { + var transport = (provider, req, res, session) => (data) => { if (!provider.callback) { res.end(qs.stringify(data)) } else if (!provider.transport || provider.transport === 'querystring') { - res.redirect(`${provider.callback}?${qs.stringify(data)}`) + redirect(req, res, `${provider.callback}?${qs.stringify(data)}`) } else if (provider.transport === 'session') { session.response = data - res.redirect(provider.callback) + redirect(req, res, provider.callback) } } + var redirect = (req, res, url) => + typeof req.session.save === 'function' + ? req.session.save(() => res.redirect(url)) + : res.redirect(url) + function connect (req, res) { var session = req.session.grant var provider = config.provider(app.config, session) - var response = transport(provider, res, session) + var response = transport(provider, req, res, session) if (provider.oauth === 1) { oauth1.request(provider) .then(({body}) => { session.request = body oauth1.authorize(provider, body) - .then((url) => res.redirect(url)) + .then((url) => redirect(req, res, url)) .catch(response) }) .catch(response) @@ -74,7 +79,7 @@ module.exports = function (_config) { session.state = provider.state session.nonce = provider.nonce oauth2.authorize(provider) - .then((url) => res.redirect(url)) + .then((url) => redirect(req, res, url)) .catch(response) } @@ -86,7 +91,7 @@ module.exports = function (_config) { function callback (req, res) { var session = req.session.grant || {} var provider = config.provider(app.config, session) - var response = transport(provider, res, session) + var response = transport(provider, req, res, session) if (provider.oauth === 1) { oauth1.access(provider, session.request, req.query)