Skip to content

Commit 36d4590

Browse files
committed
Prompt to save session
1 parent 207ae88 commit 36d4590

3 files changed

Lines changed: 66 additions & 16 deletions

File tree

src/repl/conf.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ func NewConfigOptions() *ConfigOptions {
6969
co.RegisterOption("history", BooleanOption, "Enable REPL history", "true")
7070
// Enable automatic AI-generated session topics (#topic)
7171
co.RegisterOption("aitopic", BooleanOption, "Enable automatic AI-generated session topics", "false")
72+
// Set session save behavior on exit: always, never, or prompt
73+
co.RegisterOption("session_save", StringOption, "Session save behavior on exit: always, never, or prompt", "prompt")
7274

7375
co.initialized = true
7476

src/repl/repl.go

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,46 @@ type StreamingClient interface {
7272
StreamChat(ctx context.Context, messages []Message) (<-chan string, <-chan error)
7373
}
7474

75+
// AskYesNo prompts the user with a yes/no question, defaulting to 'y' or 'n'.
76+
// Returns true for yes, false for no.
77+
func AskYesNo(question string, defaultVal rune) bool {
78+
defaultVal = rune(strings.ToLower(string(defaultVal))[0])
79+
if defaultVal != 'y' && defaultVal != 'n' {
80+
panic("default value must be 'y' or 'n'")
81+
}
82+
83+
var defaultText string
84+
if defaultVal == 'y' {
85+
defaultText = "[Y/n]"
86+
} else {
87+
defaultText = "[y/N]"
88+
}
89+
90+
fmt.Printf("%s %s ", question, defaultText)
91+
92+
// Put terminal in raw mode
93+
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
94+
if err != nil {
95+
panic(err)
96+
}
97+
defer term.Restore(int(os.Stdin.Fd()), oldState)
98+
99+
// Read one byte
100+
var buf [1]byte
101+
_, err = os.Stdin.Read(buf[:])
102+
if err != nil {
103+
panic(err)
104+
}
105+
106+
c := buf[0]
107+
if c == '\r' || c == '\n' { // Enter pressed -> use default
108+
return defaultVal == 'y'
109+
}
110+
111+
c = byte(strings.ToLower(string(c))[0])
112+
return c == 'y'
113+
}
114+
75115
func NewREPL(config *Config) (*REPL, error) {
76116
ctx, cancel := context.WithCancel(context.Background())
77117

@@ -386,17 +426,25 @@ func (r *REPL) cleanup() {
386426
// Auto-save the chat session if history is enabled and messages exist,
387427
// updating the current session or creating a new one if none selected
388428
if r.config.options.GetBool("history") && len(r.messages) > 0 {
389-
var name string
390-
if r.currentSession != "" {
391-
name = r.currentSession
392-
} else {
393-
// name = time.Now().Format("20060102150405")
394-
name = time.Now().Format("05041502012006")
395-
}
396-
if err := r.saveSession(name); err != nil {
397-
fmt.Fprintf(os.Stderr, "Error auto-saving session: %v\n", err)
429+
mode := r.config.options.Get("session_save")
430+
if mode != "never" {
431+
var name string
432+
if r.currentSession != "" {
433+
name = r.currentSession
434+
} else {
435+
// name = time.Now().Format("20060102150405")
436+
name = time.Now().Format("05041502012006")
437+
}
438+
if mode == "prompt" {
439+
if !AskYesNo("Save session %q? (Y/n) ", 'y') {
440+
return
441+
}
442+
}
443+
if err := r.saveSession(name); err != nil {
444+
fmt.Fprintf(os.Stderr, "Error auto-saving session: %v\n", err)
445+
}
446+
r.currentSession = name
398447
}
399-
r.currentSession = name
400448
}
401449
}
402450

src/repl/tools.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -250,18 +250,18 @@ func (r *REPL) toolStep(toolPrompt string, input string, ctx string, toolList st
250250
*/
251251
messages := []Message{{"user", query}}
252252
/*
253-
fmt.Println("-------------------------8<-------------------------")
254-
fmt.Println(query)
255-
fmt.Println("------------------------->8-------------------------")
253+
fmt.Println("-------------------------8<-------------------------")
254+
fmt.Println(query)
255+
fmt.Println("------------------------->8-------------------------")
256256
*/
257257
responseText, err := r.currentClient.SendMessage(messages, false)
258258
if err != nil {
259259
return PlanResponse{}, "", fmt.Errorf("failed to get response for tools: %v", err)
260260
}
261261
/*
262-
fmt.Println("-------------------------8<-------------------------")
263-
fmt.Println(responseText)
264-
fmt.Println("------------------------->8-------------------------")
262+
fmt.Println("-------------------------8<-------------------------")
263+
fmt.Println(responseText)
264+
fmt.Println("------------------------->8-------------------------")
265265
*/
266266
// strip out any internal reasoning between <think>...</think> before processing
267267
reThink := regexp.MustCompile(`(?s)\s*<think>.*?</think>\s*`)

0 commit comments

Comments
 (0)