@@ -70,6 +70,14 @@ def get_schema(self, request=None, public=False):
7070 """
7171 self ._initialise_endpoints ()
7272 components_schemas = {}
73+ security_schemes_schemas = {}
74+ root_security_requirements = []
75+
76+ if api_settings .DEFAULT_AUTHENTICATION_CLASSES :
77+ for auth_class in api_settings .DEFAULT_AUTHENTICATION_CLASSES :
78+ req = auth_class .openapi_security_requirement (None , None )
79+ if req :
80+ root_security_requirements += req
7381
7482 # Iterate endpoints generating per method path operations.
7583 paths = {}
@@ -80,6 +88,7 @@ def get_schema(self, request=None, public=False):
8088
8189 operation = view .schema .get_operation (path , method )
8290 components = view .schema .get_components (path , method )
91+
8392 for k in components .keys ():
8493 if k not in components_schemas :
8594 continue
@@ -89,6 +98,16 @@ def get_schema(self, request=None, public=False):
8998
9099 components_schemas .update (components )
91100
101+ security_schemes = view .schema .get_security_schemes (path , method )
102+ for k in security_schemes .keys ():
103+ if k not in security_schemes_schemas :
104+ continue
105+ if security_schemes_schemas [k ] == security_schemes [k ]:
106+ continue
107+ warnings .warn ('Security scheme component "{}" has been overriden with a different '
108+ 'value.' .format (k ))
109+ security_schemes_schemas .update (security_schemes )
110+
92111 # Normalise path for any provided mount url.
93112 if path .startswith ('/' ):
94113 path = path [1 :]
@@ -111,6 +130,14 @@ def get_schema(self, request=None, public=False):
111130 'schemas' : components_schemas
112131 }
113132
133+ if len (security_schemes_schemas ) > 0 :
134+ if 'components' not in schema :
135+ schema ['components' ] = {}
136+ schema ['components' ]['securitySchemes' ] = security_schemes_schemas
137+
138+ if len (root_security_requirements ) > 0 :
139+ schema ['security' ] = root_security_requirements
140+
114141 return schema
115142
116143# View Inspectors
@@ -146,6 +173,9 @@ def get_operation(self, path, method):
146173
147174 operation ['operationId' ] = self .get_operation_id (path , method )
148175 operation ['description' ] = self .get_description (path , method )
176+ security = self .get_security_requirements (path , method )
177+ if security is not None :
178+ operation ['security' ] = security
149179
150180 parameters = []
151181 parameters += self .get_path_parameters (path , method )
@@ -692,6 +722,34 @@ def get_tags(self, path, method):
692722
693723 return [path .split ('/' )[0 ].replace ('_' , '-' )]
694724
725+ def get_security_schemes (self , path , method ):
726+ """
727+ Get components.schemas.securitySchemes required by this path.
728+ returns dict of securitySchemes.
729+ """
730+ schemes = {}
731+ for auth_class in self .view .authentication_classes :
732+ if hasattr (auth_class , 'openapi_security_scheme' ):
733+ schemes .update (auth_class .openapi_security_scheme ())
734+ return schemes
735+
736+ def get_security_requirements (self , path , method ):
737+ """
738+ Get Security Requirement Object list for this operation.
739+ Returns a list of security requirement objects based on the view's authentication classes
740+ unless this view's authentication classes are the same as the root-level defaults.
741+ """
742+ # references the securityScheme names described above in get_security_schemes()
743+ security = []
744+ if self .view .authentication_classes == api_settings .DEFAULT_AUTHENTICATION_CLASSES :
745+ return None
746+ for auth_class in self .view .authentication_classes :
747+ if hasattr (auth_class , 'openapi_security_requirement' ):
748+ req = auth_class .openapi_security_requirement (self .view , method )
749+ if req :
750+ security += req
751+ return security
752+
695753 def _get_path_parameters (self , path , method ):
696754 warnings .warn (
697755 "Method `_get_path_parameters()` has been renamed to `get_path_parameters()`. "
0 commit comments