99from functools import lru_cache
1010
1111# Cache the model loading
12- # @gr.cache
1312@lru_cache (maxsize = 1 )
1413def load_model ():
1514 model_path = "MMADS/MoralFoundationsClassifier"
@@ -144,34 +143,48 @@ def process_text(text):
144143
145144 return scores_text , bar_chart
146145
147- def process_csv (file ):
146+ def process_csv (file , progress = gr . Progress () ):
148147 """Process CSV file with multiple texts"""
149148 if file is None :
150- return "Please upload a CSV file" , None , None
149+ return "Please upload a CSV file" , None , None , None
151150
152151 try :
153152 # Read CSV
154153 df = pd .read_csv (file .name )
155154
156155 if 'text' not in df .columns :
157- return "Error: CSV must contain a 'text' column" , None , None
156+ return "Error: CSV must contain a 'text' column" , None , None , None
158157
159158 texts = df ['text' ].tolist ()
160159
161160 # Load model and process in batches
161+ progress (0 , desc = "Loading model..." )
162162 model , tokenizer , label_names = load_model ()
163163
164164 # Process in batches of 32
165165 batch_size = 32
166166 all_results = []
167+ total_batches = (len (texts ) + batch_size - 1 ) // batch_size
167168
168169 for i in range (0 , len (texts ), batch_size ):
170+ batch_num = i // batch_size + 1
171+ progress (batch_num / total_batches , desc = f"Processing batch { batch_num } /{ total_batches } " )
172+
169173 batch_texts = texts [i :i + batch_size ]
170174 batch_results = predict_batch (batch_texts , model , tokenizer , label_names )
171175 all_results .extend (batch_results )
172176
177+ progress (0.9 , desc = "Creating visualizations..." )
178+
173179 # Create summary
174180 summary = f"**Processed { len (texts )} texts**\n \n "
181+ summary += "**Average Scores Across All Texts:**\n \n "
182+
183+ # Calculate average scores
184+ avg_scores = {}
185+ for label in label_names :
186+ avg_scores [label ] = np .mean ([r ['scores' ][label ] for r in all_results ])
187+ summary += f"{ label .replace ('_' , ' ' ).title ()} : { avg_scores [label ]:.4f} \n "
175188
176189 # Create visualizations
177190 bar_chart = create_visualization (all_results )
@@ -185,12 +198,14 @@ def process_csv(file):
185198 } for r in all_results
186199 ])
187200
188- results_df .to_csv ('results.csv' , index = False )
201+ # Save to a temporary file and return the path
202+ output_path = "results.csv"
203+ results_df .to_csv (output_path , index = False )
189204
190- return summary + "Results saved to results.csv" , bar_chart , heatmap
205+ return summary , bar_chart , heatmap , output_path
191206
192207 except Exception as e :
193- return f"Error processing CSV: { str (e )} " , None , None
208+ return f"Error processing CSV: { str (e )} " , None , None , None
194209
195210# Create example texts
196211example_texts = [
@@ -244,6 +259,8 @@ def process_csv(file):
244259 gr .Markdown ("""
245260 Upload a CSV file with a 'text' column containing the texts to analyze.
246261 The app will process all texts and provide aggregate visualizations.
262+
263+ A sample CSV file is available for download [here](https://huggingface.co/spaces/MMADS/MoralFoundationsClassifier-app/tree/main/examples).
247264 """ )
248265
249266 csv_input = gr .File (
@@ -257,12 +274,15 @@ def process_csv(file):
257274
258275 with gr .Row ():
259276 bar_output = gr .Plot (label = "Average Scores" )
260- heatmap_output = gr .Plot (label = "Scores Heatmap" )
277+ heatmap_output = gr .Plot (label = "Scores Heatmap (First 20 texts)" )
278+
279+ # Add download component
280+ download_output = gr .File (label = "Download Results" , visible = True )
261281
262282 process_btn .click (
263283 fn = process_csv ,
264284 inputs = csv_input ,
265- outputs = [summary_output , bar_output , heatmap_output ]
285+ outputs = [summary_output , bar_output , heatmap_output , download_output ]
266286 )
267287
268288 gr .Markdown ("""
0 commit comments