import json
from functools import lru_cache

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
from transformers import RobertaForSequenceClassification, RobertaTokenizer
3- def greet (name ):
4- return "Hello " + name + "!!"
# Cache the model loading so the checkpoint is only fetched and initialised once.
# NOTE: the original `@gr.cache` decorator does not exist in the Gradio API and
# raises AttributeError at import time; functools.lru_cache is the correct tool.
@lru_cache(maxsize=1)
def load_model():
    """Load the classifier, its tokenizer, and the output label names.

    Returns:
        tuple: (model, tokenizer, label_names) where label_names[j] names
        the j-th output logit of the classification head.
    """
    model_path = "MMADS/MoralFoundationsClassifier"
    model = RobertaForSequenceClassification.from_pretrained(model_path)
    tokenizer = RobertaTokenizer.from_pretrained(model_path)

    # Five moral foundations, each with a virtue and a vice dimension,
    # in the order the classification head emits them.
    label_names = [
        "care_virtue", "care_vice",
        "fairness_virtue", "fairness_vice",
        "loyalty_virtue", "loyalty_vice",
        "authority_virtue", "authority_vice",
        "sanctity_virtue", "sanctity_vice",
    ]

    return model, tokenizer, label_names
529
6- demo = gr .Interface (fn = greet , inputs = "text" , outputs = "text" )
7- demo .launch ()
def predict_batch(texts, model, tokenizer, label_names):
    """Classify a list of texts in one forward pass.

    Args:
        texts: list of input strings.
        model: sequence-classification model returning `.logits`.
        tokenizer: callable producing `return_tensors="pt"` encodings.
        label_names: names for each output logit, in head order.

    Returns:
        list of dicts: [{'text': ..., 'scores': {label: float}}, ...].
    """
    # One tokenizer call handles padding/truncation for the whole batch.
    encoded = tokenizer(texts, padding=True, truncation=True,
                        max_length=512, return_tensors="pt")

    # Inference only — no gradients needed; sigmoid gives independent
    # per-label probabilities (multi-label head).
    with torch.no_grad():
        probs = torch.sigmoid(model(**encoded).logits)

    probs = probs.numpy()

    return [
        {'text': text, 'scores': dict(zip(label_names, map(float, row)))}
        for text, row in zip(texts, probs)
    ]
53+
def create_visualization(results):
    """Build a grouped bar chart of mean virtue/vice scores per foundation.

    Args:
        results: list of {'text': str, 'scores': {label: float}} dicts as
            produced by predict_batch. Falsy input yields None.

    Returns:
        plotly Figure, or None when there is nothing to plot.
    """
    if not results:
        return None

    # Mean score for every label across all analysed texts.
    mean_score = {
        label: np.mean([entry['scores'][label] for entry in results])
        for label in results[0]['scores']
    }

    foundations = ['care', 'fairness', 'loyalty', 'authority', 'sanctity']
    virtue_means = [mean_score[f"{name}_virtue"] for name in foundations]
    vice_means = [mean_score[f"{name}_vice"] for name in foundations]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        name='Virtues',
        x=foundations,
        y=virtue_means,
        marker_color='lightgreen'
    ))
    fig.add_trace(go.Bar(
        name='Vices',
        x=foundations,
        y=vice_means,
        marker_color='lightcoral'
    ))

    fig.update_layout(
        title="Average Moral Foundation Scores",
        xaxis_title="Moral Foundations",
        yaxis_title="Average Score",
        barmode='group',
        yaxis=dict(range=[0, 1]),  # scores are sigmoid outputs in [0, 1]
        template="plotly_white",
    )

    return fig
101+
def create_heatmap(results):
    """Render a text-by-label heatmap of classifier scores.

    Args:
        results: list of {'text': str, 'scores': {label: float}} dicts.
            Falsy input yields None.

    Returns:
        plotly Figure, or None when there is nothing to plot.
    """
    if not results:
        return None

    def shorten(text):
        # Keep y-axis labels readable by truncating long inputs.
        return text[:50] + "..." if len(text) > 50 else text

    row_labels = [shorten(entry['text']) for entry in results]
    col_labels = list(results[0]['scores'].keys())
    matrix = [
        [entry['scores'][label] for label in col_labels]
        for entry in results
    ]

    fig = px.imshow(
        matrix,
        labels=dict(x="Moral Foundations", y="Texts", color="Score"),
        x=col_labels,
        y=row_labels,
        aspect="auto",
        color_continuous_scale="RdBu_r",
    )
    fig.update_layout(
        title="Moral Foundation Scores Heatmap",
        # Grow the figure with the row count so labels do not overlap.
        height=max(400, len(row_labels) * 30),
    )

    return fig
130+
def process_text(text):
    """Classify a single text and format the result for the UI.

    Args:
        text: the user-entered string to analyse.

    Returns:
        tuple: (markdown string of per-label scores, plotly bar chart).
    """
    model, tokenizer, label_names = load_model()
    results = predict_batch([text], model, tokenizer, label_names)

    # Emit a Markdown bullet list: a single "\n" between plain lines does
    # not render as a line break in Markdown, which collapsed every score
    # onto one line in the original output.
    lines = ["**Moral Foundation Scores:**", ""]
    for label, score in results[0]['scores'].items():
        foundation = label.replace('_', ' ').title()
        lines.append(f"- {foundation}: {score:.4f}")
    scores_text = "\n".join(lines)

    # Re-use the aggregate bar chart for the single-text case.
    bar_chart = create_visualization(results)

    return scores_text, bar_chart
146+
def process_csv(file):
    """Classify every row of an uploaded CSV and build aggregate charts.

    Args:
        file: value from gr.File. Depending on the Gradio version this is
            either a tempfile-like object exposing `.name` or a plain path
            string; both are accepted. None means nothing was uploaded.

    Returns:
        tuple: (markdown summary, bar chart figure, heatmap figure); the
        figures are None on any error.
    """
    if file is None:
        return "Please upload a CSV file", None, None

    try:
        # Newer Gradio versions pass the upload as a str path, older ones
        # pass a wrapper object with a .name attribute — support both.
        path = file if isinstance(file, str) else file.name
        df = pd.read_csv(path)

        if 'text' not in df.columns:
            return "Error: CSV must contain a 'text' column", None, None

        texts = df['text'].tolist()

        model, tokenizer, label_names = load_model()

        # Run inference in fixed-size batches to bound peak memory.
        batch_size = 32
        all_results = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            all_results.extend(predict_batch(batch, model, tokenizer, label_names))

        summary = f"**Processed {len(texts)} texts**\n\n"

        bar_chart = create_visualization(all_results)
        # Heatmap rows become unreadable past a few dozen texts — cap at 20.
        heatmap = create_heatmap(all_results[:20])

        # Flatten the per-text score dicts into one row per text for download.
        results_df = pd.DataFrame([
            {'text': r['text'], **r['scores']}
            for r in all_results
        ])
        results_df.to_csv('results.csv', index=False)

        return summary + "Results saved to results.csv", bar_chart, heatmap

    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        return f"Error processing CSV: {str(e)}", None, None
194+
# Example sentences, one per moral foundation (care, fairness, loyalty,
# authority, sanctity), used to pre-populate the single-text tab.
example_texts = [
    "We must protect the vulnerable and care for those who cannot care for themselves.",
    "Everyone deserves equal treatment under the law, regardless of their background.",
    "Betraying your country is one of the worst things a person can do.",
    "We should respect our elders and follow traditional values.",
    "Some things are sacred and should not be violated or mocked."
]
203+
# Create Gradio interface: two tabs (single-text analysis and batch CSV analysis).
with gr.Blocks(title="Moral Foundations Classifier") as demo:
    # Intro / theory overview shown above both tabs.
    gr.Markdown("""
    # Moral Foundations Classifier

    This app analyzes text for moral foundations based on Moral Foundations Theory.
    It identifies five moral foundations (each with virtue and vice dimensions):
    - **Care/Harm**: Compassion and protection vs. harm
    - **Fairness/Cheating**: Justice and equality vs. cheating
    - **Loyalty/Betrayal**: Group loyalty vs. betrayal
    - **Authority/Subversion**: Respect for authority vs. subversion
    - **Sanctity/Degradation**: Purity and sanctity vs. degradation
    """)

    with gr.Tab("Single Text Analysis"):
        text_input = gr.Textbox(
            label="Enter text to analyze",
            placeholder="Type or paste your text here...",
            lines=5
        )

        # Clickable example sentences that populate the textbox above.
        gr.Examples(
            examples=example_texts,
            inputs=text_input,
            label="Example Texts"
        )

        analyze_btn = gr.Button("Analyze Text", variant="primary")

        with gr.Row():
            scores_output = gr.Markdown(label="Scores")
            chart_output = gr.Plot(label="Visualization")

        # Wire the button to the single-text pipeline.
        analyze_btn.click(
            fn=process_text,
            inputs=text_input,
            outputs=[scores_output, chart_output]
        )

    with gr.Tab("Batch Analysis (CSV)"):
        gr.Markdown("""
        Upload a CSV file with a 'text' column containing the texts to analyze.
        The app will process all texts and provide aggregate visualizations.
        """)

        csv_input = gr.File(
            label="Upload CSV file",
            file_types=[".csv"]
        )

        process_btn = gr.Button("Process CSV", variant="primary")

        summary_output = gr.Markdown(label="Summary")

        with gr.Row():
            bar_output = gr.Plot(label="Average Scores")
            heatmap_output = gr.Plot(label="Scores Heatmap")

        # Wire the button to the batch pipeline.
        process_btn.click(
            fn=process_csv,
            inputs=csv_input,
            outputs=[summary_output, bar_output, heatmap_output]
        )

    # Attribution footer.
    gr.Markdown("""
    ---
    Based on the [MoralFoundationsClassifier](https://huggingface.co/MMADS/MoralFoundationsClassifier) by M. Murat Ardag
    """)

if __name__ == "__main__":
    demo.launch()
0 commit comments