Skip to content

Commit abb7957

Browse files
committed
added the actual app; updated the requirements.txt
1 parent 02b59bf commit abb7957

File tree

2 files changed

+272
-5
lines changed

2 files changed

+272
-5
lines changed

app.py

Lines changed: 271 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,274 @@
11
import json
from functools import lru_cache

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification
210

3-
def greet(name):
4-
return "Hello " + name + "!!"
11+
# Cache the model loading so the weights are pulled from the Hub only once
# per process. NOTE: the previous `@gr.cache` decorator does not exist in
# Gradio and raised AttributeError at import time.
@lru_cache(maxsize=1)
def load_model():
    """Load the classifier, its tokenizer, and the output label names.

    Returns:
        tuple: ``(model, tokenizer, label_names)`` where ``label_names[j]``
        names column ``j`` of the model's logits.
    """
    model_path = "MMADS/MoralFoundationsClassifier"
    model = RobertaForSequenceClassification.from_pretrained(model_path)
    tokenizer = RobertaTokenizer.from_pretrained(model_path)

    # Five moral foundations, each with a virtue and a vice dimension.
    # Order must match the model's output head.
    label_names = [
        "care_virtue", "care_vice",
        "fairness_virtue", "fairness_vice",
        "loyalty_virtue", "loyalty_vice",
        "authority_virtue", "authority_vice",
        "sanctity_virtue", "sanctity_vice"
    ]

    return model, tokenizer, label_names
529

6-
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7-
demo.launch()
30+
def predict_batch(texts, model, tokenizer, label_names):
    """Classify a list of texts and return one score dict per text.

    Args:
        texts: list of input strings.
        model: sequence-classification model whose logit columns align
            with ``label_names``.
        tokenizer: matching tokenizer (invoked with padding/truncation).
        label_names: label for each logit column.

    Returns:
        list of ``{'text': str, 'scores': {label: float}}`` dicts, in the
        same order as ``texts``.
    """
    encoded = tokenizer(
        texts, padding=True, truncation=True, max_length=512, return_tensors="pt"
    )

    # Multi-label head: an independent sigmoid per label, not a softmax.
    with torch.no_grad():
        probabilities = torch.sigmoid(model(**encoded).logits).numpy()

    return [
        {
            "text": text,
            "scores": {
                label: float(probabilities[row, col])
                for col, label in enumerate(label_names)
            },
        }
        for row, text in enumerate(texts)
    ]
53+
54+
def create_visualization(results):
    """Build a grouped bar chart of mean virtue/vice scores per foundation.

    Args:
        results: list of ``{'text': ..., 'scores': {label: float}}`` dicts.

    Returns:
        A plotly Figure, or ``None`` when ``results`` is empty.
    """
    if not results:
        return None

    foundations = ['care', 'fairness', 'loyalty', 'authority', 'sanctity']

    # Mean score across every analyzed text for one label.
    def mean_score(label):
        return np.mean([entry['scores'][label] for entry in results])

    virtue_means = [mean_score(f"{name}_virtue") for name in foundations]
    vice_means = [mean_score(f"{name}_vice") for name in foundations]

    figure = go.Figure()
    figure.add_trace(go.Bar(
        name='Virtues',
        x=foundations,
        y=virtue_means,
        marker_color='lightgreen'
    ))
    figure.add_trace(go.Bar(
        name='Vices',
        x=foundations,
        y=vice_means,
        marker_color='lightcoral'
    ))

    # Scores are sigmoid outputs, so pin the y-axis to [0, 1].
    figure.update_layout(
        title="Average Moral Foundation Scores",
        xaxis_title="Moral Foundations",
        yaxis_title="Average Score",
        barmode='group',
        yaxis=dict(range=[0, 1]),
        template="plotly_white"
    )

    return figure
101+
102+
def create_heatmap(results):
    """Render per-text scores as a texts-by-labels heatmap.

    Args:
        results: list of ``{'text': ..., 'scores': {label: float}}`` dicts.

    Returns:
        A plotly Figure, or ``None`` when ``results`` is empty.
    """
    if not results:
        return None

    labels = list(results[0]['scores'].keys())

    # Truncate long texts so the y-axis tick labels stay readable.
    def shorten(text):
        return text[:50] + "..." if len(text) > 50 else text

    row_names = [shorten(entry['text']) for entry in results]
    score_matrix = [
        [entry['scores'][label] for label in labels]
        for entry in results
    ]

    figure = px.imshow(
        score_matrix,
        labels=dict(x="Moral Foundations", y="Texts", color="Score"),
        x=labels,
        y=row_names,
        aspect="auto",
        color_continuous_scale="RdBu_r"
    )

    # Grow the figure with the row count so individual cells stay legible.
    figure.update_layout(
        title="Moral Foundation Scores Heatmap",
        height=max(400, len(row_names) * 30)
    )

    return figure
130+
131+
def process_text(text):
    """Analyze a single text.

    Args:
        text: the input string to classify.

    Returns:
        tuple: (markdown-formatted score listing, bar-chart figure).
    """
    model, tokenizer, label_names = load_model()
    results = predict_batch([text], model, tokenizer, label_names)

    # Render each label as "Care Virtue: 0.1234" under a bold heading.
    lines = ["**Moral Foundation Scores:**", ""]
    for label, score in results[0]['scores'].items():
        pretty_name = label.replace('_', ' ').title()
        lines.append(f"{pretty_name}: {score:.4f}")
    scores_text = "\n".join(lines) + "\n"

    return scores_text, create_visualization(results)
146+
147+
def process_csv(file):
    """Process a CSV of texts and return (summary, bar chart, heatmap).

    Args:
        file: the ``gr.File`` upload — depending on the Gradio version this
            is either a filepath string or a tempfile-like object with a
            ``.name`` attribute; both are accepted.

    Returns:
        tuple: (markdown summary str, bar-chart figure or None,
        heatmap figure or None).
    """
    if file is None:
        return "Please upload a CSV file", None, None

    try:
        # gr.File may hand us a filepath string (gradio >= 4 default) or a
        # tempfile-like object exposing .name — support both.
        path = file if isinstance(file, str) else file.name
        df = pd.read_csv(path)

        if 'text' not in df.columns:
            return "Error: CSV must contain a 'text' column", None, None

        # Drop missing rows and coerce to str so the tokenizer never sees NaN.
        texts = df['text'].dropna().astype(str).tolist()
        if not texts:
            return "Error: the 'text' column contains no usable rows", None, None

        model, tokenizer, label_names = load_model()

        # Process in batches of 32 to bound memory use on large files.
        batch_size = 32
        all_results = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            all_results.extend(
                predict_batch(batch_texts, model, tokenizer, label_names)
            )

        summary = f"**Processed {len(texts)} texts**\n\n"

        bar_chart = create_visualization(all_results)
        heatmap = create_heatmap(all_results[:20])  # Limit heatmap to first 20 texts

        # Persist a flat per-text score table for download.
        results_df = pd.DataFrame([
            {
                'text': r['text'],
                **r['scores']
            } for r in all_results
        ])
        results_df.to_csv('results.csv', index=False)

        return summary + "Results saved to results.csv", bar_chart, heatmap

    except Exception as e:
        # UI boundary: surface the failure as a message instead of crashing.
        return f"Error processing CSV: {str(e)}", None, None
194+
195+
# Create example texts
# One sample sentence per moral foundation (care, fairness, loyalty,
# authority, sanctity) used to seed the gr.Examples widget below.
example_texts = [
    "We must protect the vulnerable and care for those who cannot care for themselves.",
    "Everyone deserves equal treatment under the law, regardless of their background.",
    "Betraying your country is one of the worst things a person can do.",
    "We should respect our elders and follow traditional values.",
    "Some things are sacred and should not be violated or mocked."
]
203+
204+
# Create Gradio interface
# Two-tab Blocks layout: single-text analysis and batch CSV analysis.
with gr.Blocks(title="Moral Foundations Classifier") as demo:
    gr.Markdown("""
    # Moral Foundations Classifier

    This app analyzes text for moral foundations based on Moral Foundations Theory.
    It identifies five moral foundations (each with virtue and vice dimensions):
    - **Care/Harm**: Compassion and protection vs. harm
    - **Fairness/Cheating**: Justice and equality vs. cheating
    - **Loyalty/Betrayal**: Group loyalty vs. betrayal
    - **Authority/Subversion**: Respect for authority vs. subversion
    - **Sanctity/Degradation**: Purity and sanctity vs. degradation
    """)

    # Tab 1: free-form text box -> scores markdown + bar chart.
    with gr.Tab("Single Text Analysis"):
        text_input = gr.Textbox(
            label="Enter text to analyze",
            placeholder="Type or paste your text here...",
            lines=5
        )

        # Pre-canned inputs, one per foundation, to demo the classifier.
        gr.Examples(
            examples=example_texts,
            inputs=text_input,
            label="Example Texts"
        )

        analyze_btn = gr.Button("Analyze Text", variant="primary")

        with gr.Row():
            scores_output = gr.Markdown(label="Scores")
            chart_output = gr.Plot(label="Visualization")

        # Wire the button to the single-text pipeline.
        analyze_btn.click(
            fn=process_text,
            inputs=text_input,
            outputs=[scores_output, chart_output]
        )

    # Tab 2: CSV upload -> summary markdown + aggregate charts.
    with gr.Tab("Batch Analysis (CSV)"):
        gr.Markdown("""
        Upload a CSV file with a 'text' column containing the texts to analyze.
        The app will process all texts and provide aggregate visualizations.
        """)

        csv_input = gr.File(
            label="Upload CSV file",
            file_types=[".csv"]
        )

        process_btn = gr.Button("Process CSV", variant="primary")

        summary_output = gr.Markdown(label="Summary")

        with gr.Row():
            bar_output = gr.Plot(label="Average Scores")
            heatmap_output = gr.Plot(label="Scores Heatmap")

        # Wire the button to the batch pipeline.
        process_btn.click(
            fn=process_csv,
            inputs=csv_input,
            outputs=[summary_output, bar_output, heatmap_output]
        )

    gr.Markdown("""
    ---
    Based on the [MoralFoundationsClassifier](https://huggingface.co/MMADS/MoralFoundationsClassifier) by M. Murat Ardag
    """)

# Launch only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
gradio==4.19.2
1+
gradio==4.44.1
22
transformers==4.38.0
33
torch==2.2.0
44
pandas==2.2.0
plotly==5.18.0
numpy==1.26.4

0 commit comments

Comments
 (0)