Skip to content

Commit 05bef47

Browse files
mattgotteinerMatt Gotteiner
andauthored
Add support for an optional login and document level access control system. (Azure-Samples#624)
* conditional login button * fixing conditional login button * updating frontend * snapshot: OBO flow works * auth login working e2e * cannot use env vars from frontend * add adls gen2 setup * more changes to prepdocs * fix auth + streaming * fixing up scripts * add view action to manageacl * Writing documentation * doc WIP * push auth config from server to client * updating docs, some minor code edits to be consistent * checkpoint * manual setup only for now * remove manual logging * remove optional print * typo * hosting on localhost for redirect uri * remove ms graph sdk * run black, ruff * dependency injection for AuthenticationHelper * encrypted token cache * more feedback * more feedback, port adlsgen2 to python * ruff, black * ruff, black don't change files i didn't write * fix manage acl script * update start to support codespaces * run black * manual test, github codespaces localhost still works * fixing prepdocs after manual test of azd up without auth * adding sh files; fixing script errors * debugging auth on codespaces * running through setup instructions * note about consent * change default scope * switch to unordered list * missing note * addressing feedback... * more feedback around * doc strings * formatting * feedback on group claims * switch to transitivememberof * readme feedback * refactor approach to use common filtering method * more feedback * refactoring * writing tests * tests * test adls gen2 prepdocs * fixing tests using env vars; adding adls gen2 tests * broken? * fixing tests * more tests * fixing CI errors * feedback * fix script * fix script * fix script * bicep deployment; add documentation for troubleshooting * lowercase true for env comparison * feedback * fix sh syntax errors * fixing syntax errors * Script fixes --------- Co-authored-by: Matt Gotteiner <magottei@microsoft.com>
1 parent d61c6ee commit 05bef47

54 files changed

Lines changed: 7243 additions & 1308 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

LoginAndAclSetup.md

Lines changed: 188 additions & 0 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
- [Enabling optional features](#enabling-optional-features)
1818
- [Enabling Application Insights](#enabling-application-insights)
1919
- [Enabling authentication](#enabling-authentication)
20+
- [Enabling login and document level access control](#enabling-login-and-document-level-access-control)
2021
- [Using the app](#using-the-app)
2122
- [Running locally](#running-locally)
2223
- [Productionizing](#productionizing)
@@ -215,6 +216,10 @@ By default, the deployed Azure web app will have no authentication or access res
215216

216217
To then limit access to a specific set of users or groups, you can follow the steps from [Restrict your Azure AD app to a set of users](https://learn.microsoft.com/azure/active-directory/develop/howto-restrict-your-app-to-a-set-of-users) by changing "Assignment Required?" option under the Enterprise Application, and then assigning users/groups access. Users not granted explicit access will receive the error message -AADSTS50105: Your administrator has configured the application <app_name> to block users unless they are specifically granted ('assigned') access to the application.-
217218

219+
### Enabling login and document level access control
220+
221+
By default, the deployed Azure web app allows users to chat with all your indexed data. You can enable an optional login system using Azure Active Directory to restrict access to indexed data based on the logged in user. Enable the optional login and document level access control system by following [this guide](./LoginAndAclSetup.md).
222+
218223
## Running locally
219224

220225
You can only run locally **after** having successfully run the `azd up` command. If you haven't yet, follow the steps in [Azure deployment](#azure-deployment) above.

app/backend/app.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,15 @@
3030
from approaches.readdecomposeask import ReadDecomposeAsk
3131
from approaches.readretrieveread import ReadRetrieveReadApproach
3232
from approaches.retrievethenread import RetrieveThenReadApproach
33+
from core.authentication import AuthenticationHelper
3334

3435
CONFIG_OPENAI_TOKEN = "openai_token"
3536
CONFIG_CREDENTIAL = "azure_credential"
3637
CONFIG_ASK_APPROACHES = "ask_approaches"
3738
CONFIG_CHAT_APPROACHES = "chat_approaches"
3839
CONFIG_BLOB_CONTAINER_CLIENT = "blob_container_client"
40+
CONFIG_AUTH_CLIENT = "auth_client"
41+
CONFIG_SEARCH_CLIENT = "search_client"
3942

4043
bp = Blueprint("routes", __name__, static_folder="static")
4144

@@ -45,6 +48,13 @@ async def index():
4548
return await bp.send_static_file("index.html")
4649

4750

51+
# Empty page is recommended for login redirect to work.
52+
# See https://github.com/AzureAD/microsoft-authentication-library-for-js/blob/dev/lib/msal-browser/docs/initialization.md#redirecturi-considerations for more information
53+
@bp.route("/redirect")
54+
async def redirect():
55+
return ""
56+
57+
4858
@bp.route("/favicon.ico")
4959
async def favicon():
5060
return await bp.send_static_file("favicon.ico")
@@ -78,6 +88,8 @@ async def ask():
7888
if not request.is_json:
7989
return jsonify({"error": "request must be json"}), 415
8090
request_json = await request.get_json()
91+
auth_helper = current_app.config[CONFIG_AUTH_CLIENT]
92+
auth_claims = await auth_helper.get_auth_claims_if_enabled(request.headers)
8193
approach = request_json["approach"]
8294
try:
8395
impl = current_app.config[CONFIG_ASK_APPROACHES].get(approach)
@@ -86,7 +98,7 @@ async def ask():
8698
# Workaround for: https://github.com/openai/openai-python/issues/371
8799
async with aiohttp.ClientSession() as s:
88100
openai.aiosession.set(s)
89-
r = await impl.run(request_json["question"], request_json.get("overrides") or {})
101+
r = await impl.run(request_json["question"], request_json.get("overrides") or {}, auth_claims)
90102
return jsonify(r)
91103
except Exception as e:
92104
logging.exception("Exception in /ask")
@@ -98,6 +110,8 @@ async def chat():
98110
if not request.is_json:
99111
return jsonify({"error": "request must be json"}), 415
100112
request_json = await request.get_json()
113+
auth_helper = current_app.config[CONFIG_AUTH_CLIENT]
114+
auth_claims = await auth_helper.get_auth_claims_if_enabled(request.headers)
101115
approach = request_json["approach"]
102116
try:
103117
impl = current_app.config[CONFIG_CHAT_APPROACHES].get(approach)
@@ -106,7 +120,9 @@ async def chat():
106120
# Workaround for: https://github.com/openai/openai-python/issues/371
107121
async with aiohttp.ClientSession() as s:
108122
openai.aiosession.set(s)
109-
r = await impl.run_without_streaming(request_json["history"], request_json.get("overrides", {}))
123+
r = await impl.run_without_streaming(
124+
request_json["history"], request_json.get("overrides", {}), auth_claims
125+
)
110126
return jsonify(r)
111127
except Exception as e:
112128
logging.exception("Exception in /chat")
@@ -123,12 +139,16 @@ async def chat_stream():
123139
if not request.is_json:
124140
return jsonify({"error": "request must be json"}), 415
125141
request_json = await request.get_json()
142+
auth_helper = current_app.config[CONFIG_AUTH_CLIENT]
143+
auth_claims = await auth_helper.get_auth_claims_if_enabled(request.headers)
126144
approach = request_json["approach"]
127145
try:
128146
impl = current_app.config[CONFIG_CHAT_APPROACHES].get(approach)
129147
if not impl:
130148
return jsonify({"error": "unknown approach"}), 400
131-
response_generator = impl.run_with_streaming(request_json["history"], request_json.get("overrides", {}))
149+
response_generator = impl.run_with_streaming(
150+
request_json["history"], request_json.get("overrides", {}), auth_claims
151+
)
132152
response = await make_response(format_as_ndjson(response_generator))
133153
response.timeout = None # type: ignore
134154
return response
@@ -137,6 +157,13 @@ async def chat_stream():
137157
return jsonify({"error": str(e)}), 500
138158

139159

160+
# Send MSAL.js settings to the client UI
161+
@bp.route("/auth_setup", methods=["GET"])
162+
def auth_setup():
163+
auth_helper = current_app.config[CONFIG_AUTH_CLIENT]
164+
return jsonify(auth_helper.get_auth_setup_for_client())
165+
166+
140167
@bp.before_request
141168
async def ensure_openai_token():
142169
if openai.api_type != "azure_ad":
@@ -168,6 +195,12 @@ async def setup_clients():
168195
# Used only with non-Azure OpenAI deployments
169196
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
170197
OPENAI_ORGANIZATION = os.getenv("OPENAI_ORGANIZATION")
198+
AZURE_USE_AUTHENTICATION = os.getenv("AZURE_USE_AUTHENTICATION", "").lower() == "true"
199+
AZURE_SERVER_APP_ID = os.getenv("AZURE_SERVER_APP_ID")
200+
AZURE_SERVER_APP_SECRET = os.getenv("AZURE_SERVER_APP_SECRET")
201+
AZURE_CLIENT_APP_ID = os.getenv("AZURE_CLIENT_APP_ID")
202+
AZURE_TENANT_ID = os.getenv("AZURE_TENANT_ID")
203+
TOKEN_CACHE_PATH = os.getenv("TOKEN_CACHE_PATH")
171204

172205
KB_FIELDS_CONTENT = os.getenv("KB_FIELDS_CONTENT", "content")
173206
KB_FIELDS_SOURCEPAGE = os.getenv("KB_FIELDS_SOURCEPAGE", "sourcepage")
@@ -178,6 +211,16 @@ async def setup_clients():
178211
# If you encounter a blocking error during a DefaultAzureCredential resolution, you can exclude the problematic credential by using a parameter (ex. exclude_shared_token_cache_credential=True)
179212
azure_credential = DefaultAzureCredential(exclude_shared_token_cache_credential=True)
180213

214+
# Set up authentication helper
215+
auth_helper = AuthenticationHelper(
216+
use_authentication=AZURE_USE_AUTHENTICATION,
217+
server_app_id=AZURE_SERVER_APP_ID,
218+
server_app_secret=AZURE_SERVER_APP_SECRET,
219+
client_app_id=AZURE_CLIENT_APP_ID,
220+
tenant_id=AZURE_TENANT_ID,
221+
token_cache_path=TOKEN_CACHE_PATH,
222+
)
223+
181224
# Set up clients for Cognitive Search and Storage
182225
search_client = SearchClient(
183226
endpoint=f"https://{AZURE_SEARCH_SERVICE}.search.windows.net",
@@ -204,7 +247,9 @@ async def setup_clients():
204247
openai.organization = OPENAI_ORGANIZATION
205248

206249
current_app.config[CONFIG_CREDENTIAL] = azure_credential
250+
current_app.config[CONFIG_SEARCH_CLIENT] = search_client
207251
current_app.config[CONFIG_BLOB_CONTAINER_CLIENT] = blob_container_client
252+
current_app.config[CONFIG_AUTH_CLIENT] = auth_helper
208253

209254
# Various approaches to integrate GPT and external knowledge, most applications will use a single one of these patterns
210255
# or some derivative, here we include several for exploration purposes

app/backend/approaches/approach.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,22 @@
11
from abc import ABC, abstractmethod
22
from typing import Any
33

4+
from core.authentication import AuthenticationHelper
45

5-
class AskApproach(ABC):
6+
7+
class Approach(ABC):
8+
def build_filter(self, overrides: dict[str, Any], auth_claims: dict[str, Any]) -> str:
9+
exclude_category = overrides.get("exclude_category") or None
10+
security_filter = AuthenticationHelper.build_security_filters(overrides, auth_claims)
11+
filters = []
12+
if exclude_category:
13+
filters.append("category ne '{}'".format(exclude_category.replace("'", "''")))
14+
if security_filter:
15+
filters.append(security_filter)
16+
return None if len(filters) == 0 else " and ".join(filters)
17+
18+
19+
class AskApproach(Approach):
620
@abstractmethod
7-
async def run(self, q: str, overrides: dict[str, Any]) -> dict[str, Any]:
21+
async def run(self, q: str, overrides: dict[str, Any], auth_claims: dict[str, Any]) -> dict[str, Any]:
822
...

app/backend/approaches/chatreadretrieveread.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
from azure.search.documents.aio import SearchClient
66
from azure.search.documents.models import QueryType
77

8+
from approaches.approach import Approach
89
from core.messagebuilder import MessageBuilder
910
from core.modelhelper import get_token_limit
1011
from text import nonewlines
1112

1213

13-
class ChatReadRetrieveReadApproach:
14+
class ChatReadRetrieveReadApproach(Approach):
1415
# Chat roles
1516
SYSTEM = "system"
1617
USER = "user"
@@ -73,14 +74,17 @@ def __init__(
7374
self.chatgpt_token_limit = get_token_limit(chatgpt_model)
7475

7576
async def run_until_final_call(
76-
self, history: list[dict[str, str]], overrides: dict[str, Any], should_stream: bool = False
77+
self,
78+
history: list[dict[str, str]],
79+
overrides: dict[str, Any],
80+
auth_claims: dict[str, Any],
81+
should_stream: bool = False,
7782
) -> tuple:
7883
has_text = overrides.get("retrieval_mode") in ["text", "hybrid", None]
7984
has_vector = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
8085
use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
81-
top = overrides.get("top") or 3
82-
exclude_category = overrides.get("exclude_category") or None
83-
filter = "category ne '{}'".format(exclude_category.replace("'", "''")) if exclude_category else None
86+
top = overrides.get("top", 3)
87+
filter = self.build_filter(overrides, auth_claims)
8488

8589
user_query_request = "Generate search query for: " + history[-1]["user"]
8690

@@ -195,10 +199,8 @@ async def run_until_final_call(
195199
system_message,
196200
self.chatgpt_model,
197201
history,
198-
# Model does not handle lengthy system messages well.
199-
# Moved sources to latest user conversation to solve follow up questions prompt.
200202
history[-1]["user"] + "\n\nSources:\n" + content,
201-
max_tokens=self.chatgpt_token_limit,
203+
max_tokens=self.chatgpt_token_limit, # Model does not handle lengthy system messages well. Moving sources to latest user conversation to solve follow up questions prompt.
202204
)
203205
msg_to_display = "\n\n".join([str(message) for message in messages])
204206

@@ -219,17 +221,23 @@ async def run_until_final_call(
219221
)
220222
return (extra_info, chat_coroutine)
221223

222-
async def run_without_streaming(self, history: list[dict[str, str]], overrides: dict[str, Any]) -> dict[str, Any]:
223-
extra_info, chat_coroutine = await self.run_until_final_call(history, overrides, should_stream=False)
224+
async def run_without_streaming(
225+
self, history: list[dict[str, str]], overrides: dict[str, Any], auth_claims: dict[str, Any]
226+
) -> dict[str, Any]:
227+
extra_info, chat_coroutine = await self.run_until_final_call(
228+
history, overrides, auth_claims, should_stream=False
229+
)
224230
chat_resp = await chat_coroutine
225231
chat_content = chat_resp.choices[0].message.content
226232
extra_info["answer"] = chat_content
227233
return extra_info
228234

229235
async def run_with_streaming(
230-
self, history: list[dict[str, str]], overrides: dict[str, Any]
236+
self, history: list[dict[str, str]], overrides: dict[str, Any], auth_claims: dict[str, Any]
231237
) -> AsyncGenerator[dict, None]:
232-
extra_info, chat_coroutine = await self.run_until_final_call(history, overrides, should_stream=True)
238+
extra_info, chat_coroutine = await self.run_until_final_call(
239+
history, overrides, auth_claims, should_stream=True
240+
)
233241
yield extra_info
234242
async for event in await chat_coroutine:
235243
# "2023-07-01-preview" API version has a bug where first response has empty choices
@@ -247,8 +255,7 @@ def get_messages_from_history(
247255
) -> list:
248256
message_builder = MessageBuilder(system_prompt, model_id)
249257

250-
# Add examples to show the chat what responses we want.
251-
# It will try to mimic any responses and make sure they match the rules laid out in the system message.
258+
# Add examples to show the chat what responses we want. It will try to mimic any responses and make sure they match the rules laid out in the system message.
252259
for shot in few_shots:
253260
message_builder.append_message(shot.get("role"), shot.get("content"))
254261

app/backend/approaches/readdecomposeask.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,14 @@ def __init__(
3737
self.content_field = content_field
3838
self.openai_host = openai_host
3939

40-
async def search(self, query_text: str, overrides: dict[str, Any]) -> tuple[list[str], str]:
40+
async def search(
41+
self, query_text: str, overrides: dict[str, Any], auth_claims: dict[str, Any]
42+
) -> tuple[list[str], str]:
4143
has_text = overrides.get("retrieval_mode") in ["text", "hybrid", None]
4244
has_vector = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
4345
use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
44-
top = overrides.get("top") or 3
45-
exclude_category = overrides.get("exclude_category") or None
46-
filter = "category ne '{}'".format(exclude_category.replace("'", "''")) if exclude_category else None
46+
top = overrides.get("top", 3)
47+
filter = self.build_filter(overrides, auth_claims)
4748

4849
# If retrieval mode includes vectors, compute an embedding for the query
4950
if has_vector:
@@ -109,12 +110,12 @@ async def lookup(self, q: str) -> Optional[str]:
109110
return "\n".join([d["content"] async for d in r])
110111
return None
111112

112-
async def run(self, q: str, overrides: dict[str, Any]) -> dict[str, Any]:
113+
async def run(self, q: str, overrides: dict[str, Any], auth_claims: dict[str, Any]) -> dict[str, Any]:
113114
search_results = None
114115

115116
async def search_and_store(q: str) -> Any:
116117
nonlocal search_results
117-
search_results, content = await self.search(q, overrides)
118+
search_results, content = await self.search(q, overrides, auth_claims)
118119
return content
119120

120121
# Use to capture thought process during iterations

app/backend/approaches/readretrieveread.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,12 @@ def __init__(
6868
self.content_field = content_field
6969
self.openai_host = openai_host
7070

71-
async def retrieve(self, query_text: str, overrides: dict[str, Any]) -> Any:
71+
async def retrieve(self, query_text: str, overrides: dict[str, Any], auth_claims: dict[str, Any]) -> Any:
7272
has_text = overrides.get("retrieval_mode") in ["text", "hybrid", None]
7373
has_vector = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
7474
use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
75-
top = overrides.get("top") or 3
76-
exclude_category = overrides.get("exclude_category") or None
77-
filter = "category ne '{}'".format(exclude_category.replace("'", "''")) if exclude_category else None
75+
top = overrides.get("top", 3)
76+
filter = self.build_filter(overrides, auth_claims)
7877

7978
# If retrieval mode includes vectors, compute an embedding for the query
8079
if has_vector:
@@ -122,12 +121,12 @@ async def retrieve(self, query_text: str, overrides: dict[str, Any]) -> Any:
122121
content = "\n".join(results)
123122
return results, content
124123

125-
async def run(self, q: str, overrides: dict[str, Any]) -> dict[str, Any]:
124+
async def run(self, q: str, overrides: dict[str, Any], auth_claims: dict[str, Any]) -> dict[str, Any]:
126125
retrieve_results = None
127126

128127
async def retrieve_and_store(q: str) -> Any:
129128
nonlocal retrieve_results
130-
retrieve_results, content = await self.retrieve(q, overrides)
129+
retrieve_results, content = await self.retrieve(q, overrides, auth_claims)
131130
return content
132131

133132
# Use to capture thought process during iterations

app/backend/approaches/retrievethenread.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,12 @@ def __init__(
5757
self.sourcepage_field = sourcepage_field
5858
self.content_field = content_field
5959

60-
async def run(self, q: str, overrides: dict[str, Any]) -> dict[str, Any]:
60+
async def run(self, q: str, overrides: dict[str, Any], auth_claims: dict[str, Any]) -> dict[str, Any]:
6161
has_text = overrides.get("retrieval_mode") in ["text", "hybrid", None]
6262
has_vector = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
6363
use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
64-
top = overrides.get("top") or 3
65-
exclude_category = overrides.get("exclude_category") or None
66-
filter = "category ne '{}'".format(exclude_category.replace("'", "''")) if exclude_category else None
64+
top = overrides.get("top", 3)
65+
filter = self.build_filter(overrides, auth_claims)
6766

6867
# If retrieval mode includes vectors, compute an embedding for the query
6968
if has_vector:

0 commit comments

Comments
 (0)