Skip to content

Commit d4622cb

Browse files
committed
fixed the csv results issue
1 parent 0730894 commit d4622cb

File tree

1 file changed

+29
-9
lines changed

1 file changed

+29
-9
lines changed

app.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from functools import lru_cache
1010

1111
# Cache the model loading
12-
# @gr.cache
1312
@lru_cache(maxsize=1)
1413
def load_model():
1514
model_path = "MMADS/MoralFoundationsClassifier"
@@ -144,34 +143,48 @@ def process_text(text):
144143

145144
return scores_text, bar_chart
146145

147-
def process_csv(file):
146+
def process_csv(file, progress=gr.Progress()):
148147
"""Process CSV file with multiple texts"""
149148
if file is None:
150-
return "Please upload a CSV file", None, None
149+
return "Please upload a CSV file", None, None, None
151150

152151
try:
153152
# Read CSV
154153
df = pd.read_csv(file.name)
155154

156155
if 'text' not in df.columns:
157-
return "Error: CSV must contain a 'text' column", None, None
156+
return "Error: CSV must contain a 'text' column", None, None, None
158157

159158
texts = df['text'].tolist()
160159

161160
# Load model and process in batches
161+
progress(0, desc="Loading model...")
162162
model, tokenizer, label_names = load_model()
163163

164164
# Process in batches of 32
165165
batch_size = 32
166166
all_results = []
167+
total_batches = (len(texts) + batch_size - 1) // batch_size
167168

168169
for i in range(0, len(texts), batch_size):
170+
batch_num = i // batch_size + 1
171+
progress(batch_num / total_batches, desc=f"Processing batch {batch_num}/{total_batches}")
172+
169173
batch_texts = texts[i:i+batch_size]
170174
batch_results = predict_batch(batch_texts, model, tokenizer, label_names)
171175
all_results.extend(batch_results)
172176

177+
progress(0.9, desc="Creating visualizations...")
178+
173179
# Create summary
174180
summary = f"**Processed {len(texts)} texts**\n\n"
181+
summary += "**Average Scores Across All Texts:**\n\n"
182+
183+
# Calculate average scores
184+
avg_scores = {}
185+
for label in label_names:
186+
avg_scores[label] = np.mean([r['scores'][label] for r in all_results])
187+
summary += f"{label.replace('_', ' ').title()}: {avg_scores[label]:.4f}\n"
175188

176189
# Create visualizations
177190
bar_chart = create_visualization(all_results)
@@ -185,12 +198,14 @@ def process_csv(file):
185198
} for r in all_results
186199
])
187200

188-
results_df.to_csv('results.csv', index=False)
201+
# Save results to "results.csv" in the working directory (a fixed path, not a temporary file) and return the path for download
202+
output_path = "results.csv"
203+
results_df.to_csv(output_path, index=False)
189204

190-
return summary + "Results saved to results.csv", bar_chart, heatmap
205+
return summary, bar_chart, heatmap, output_path
191206

192207
except Exception as e:
193-
return f"Error processing CSV: {str(e)}", None, None
208+
return f"Error processing CSV: {str(e)}", None, None, None
194209

195210
# Create example texts
196211
example_texts = [
@@ -244,6 +259,8 @@ def process_csv(file):
244259
gr.Markdown("""
245260
Upload a CSV file with a 'text' column containing the texts to analyze.
246261
The app will process all texts and provide aggregate visualizations.
262+
263+
A sample CSV file is available for download [here](https://huggingface.co/spaces/MMADS/MoralFoundationsClassifier-app/tree/main/examples).
247264
""")
248265

249266
csv_input = gr.File(
@@ -257,12 +274,15 @@ def process_csv(file):
257274

258275
with gr.Row():
259276
bar_output = gr.Plot(label="Average Scores")
260-
heatmap_output = gr.Plot(label="Scores Heatmap")
277+
heatmap_output = gr.Plot(label="Scores Heatmap (First 20 texts)")
278+
279+
# Add download component
280+
download_output = gr.File(label="Download Results", visible=True)
261281

262282
process_btn.click(
263283
fn=process_csv,
264284
inputs=csv_input,
265-
outputs=[summary_output, bar_output, heatmap_output]
285+
outputs=[summary_output, bar_output, heatmap_output, download_output]
266286
)
267287

268288
gr.Markdown("""

0 commit comments

Comments
 (0)