-
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtable.go
More file actions
172 lines (150 loc) · 4.67 KB
/
table.go
File metadata and controls
172 lines (150 loc) · 4.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
// File: quantest/table.go
package quantest
import (
"bytes"
"fmt"
"log"
"sort"
"github.com/charmbracelet/lipgloss"
"github.com/fatih/color"
"github.com/olekukonko/tablewriter"
"github.com/olekukonko/tablewriter/renderer"
"github.com/olekukonko/tablewriter/tw"
)
// GenerateQuantTable builds a VRAM estimation table for a model, covering
// every quantisation type in GGUFMapping at a fixed set of context sizes.
//
// Parameters:
//   - config: The model configuration to estimate for. config.ModelName becomes
//     the table's ModelID. When config.IsOllama is false, GetHFModelConfig is
//     called for that name first and any error aborts the run.
//   - fitsVRAM: The available VRAM in GB. Zero or a negative value means
//     "detect it": GetAvailableMemory is queried, and 24 GB is used if that
//     fails.
//
// Returns:
//   - QuantResultTable: One result per quantisation type, ordered by ascending
//     BPW with ties broken by quantisation name so the order is deterministic.
//   - error: Non-nil if the model config lookup or any VRAM calculation fails;
//     the returned table is then the zero value.
//
// Example:
//
//	table, err := GenerateQuantTable(config, 24.0)
func GenerateQuantTable(config ModelConfig, fitsVRAM float64) (QuantResultTable, error) {
	// A non-positive budget cannot be a real amount of VRAM, so treat it as
	// "not specified" and fall back to detection.
	if fitsVRAM <= 0 {
		detected, err := GetAvailableMemory()
		if err != nil {
			log.Printf("Failed to get available memory: %v. Using default value.", err)
			detected = 24 // Default to 24GB if we can't determine available memory
		}
		fitsVRAM = detected
		log.Printf("Using %.2f GB as available memory for VRAM estimation", fitsVRAM)
	}

	table := QuantResultTable{ModelID: config.ModelName, FitsVRAM: fitsVRAM}
	contextSizes := []int{2048, 8192, 16384, 32768, 49152, 65536}

	// The fetched config is discarded; the call is only used to fail early
	// when the model config cannot be loaded.
	if !config.IsOllama {
		if _, err := GetHFModelConfig(config.ModelName); err != nil {
			return QuantResultTable{}, fmt.Errorf("loading HF model config for %q: %w", config.ModelName, err)
		}
	}

	for quantType, bpw := range GGUFMapping {
		var result QuantResult
		result.QuantType = quantType
		result.BPW = bpw
		result.Contexts = make(map[int]ContextVRAM, len(contextSizes))

		for _, ctxSize := range contextSizes {
			vramFP16, err := CalculateVRAM(config, bpw, ctxSize, KVCacheFP16)
			if err != nil {
				return QuantResultTable{}, fmt.Errorf("calculating VRAM for %s at context %d: %w", quantType, ctxSize, err)
			}
			vramQ8_0, err := CalculateVRAM(config, bpw, ctxSize, KVCacheQ8_0)
			if err != nil {
				return QuantResultTable{}, fmt.Errorf("calculating VRAM for %s at context %d: %w", quantType, ctxSize, err)
			}
			vramQ4_0, err := CalculateVRAM(config, bpw, ctxSize, KVCacheQ4_0)
			if err != nil {
				return QuantResultTable{}, fmt.Errorf("calculating VRAM for %s at context %d: %w", quantType, ctxSize, err)
			}

			result.Contexts[ctxSize] = ContextVRAM{
				VRAM:     vramFP16,
				VRAMQ8_0: vramQ8_0,
				VRAMQ4_0: vramQ4_0,
			}
		}

		table.Results = append(table.Results, result)
	}

	// Sort the results from lowest BPW to highest. Map iteration order is
	// random and sort.Slice is not stable, so equal BPW values are ordered by
	// quantisation name to keep the output identical between runs.
	sort.Slice(table.Results, func(i, j int) bool {
		a, b := &table.Results[i], &table.Results[j]
		if a.BPW != b.BPW {
			return a.BPW < b.BPW
		}
		return a.QuantType < b.QuantType
	})

	return table, nil
}
// PrintFormattedTable renders the quantisation results as a colourised text
// table and returns it as a string. Despite the name, nothing is written to
// standard output; the caller decides where the string goes.
//
// Each row shows the quantisation type, its BPW, and one VRAM estimate per
// context size. From the 16K column onwards the cell also carries the
// VRAMQ8_0 and VRAMQ4_0 estimates in brackets after the VRAM figure.
//
// Parameters:
//   - table: A QuantResultTable struct containing the quantisation results.
//
// Returns:
//   - string: The title line followed by the rendered table. If a row cannot
//     be appended or rendering fails, the problem is logged and whatever was
//     rendered so far is still returned.
//
// Example:
//
//	table, _ := GenerateQuantTable(config, 24.0)
//	fmt.Println(PrintFormattedTable(table))
func PrintFormattedTable(table QuantResultTable) string {
	var buf bytes.Buffer

	// Configure colors for the table
	colorCfg := renderer.ColorizedConfig{
		Header: renderer.Tint{
			FG: renderer.Colors{color.FgHiWhite}, // Bright white headers
		},
		Border: renderer.Tint{
			FG: renderer.Colors{color.FgWhite}, // White borders
		},
		Separator: renderer.Tint{
			FG: renderer.Colors{color.FgWhite}, // White separators
		},
	}

	// Side borders and column separators only; no top or bottom rule.
	rendition := tw.Rendition{
		Borders: tw.Border{
			Left:   tw.On,
			Top:    tw.Off,
			Right:  tw.On,
			Bottom: tw.Off,
		},
		Symbols: tw.NewSymbols(tw.StyleLight),
		Settings: tw.Settings{
			Separators: tw.Separators{
				BetweenColumns: tw.On,
			},
		},
	}

	// Named tbl rather than tw so the imported tw package is not shadowed.
	tbl := tablewriter.NewTable(&buf,
		tablewriter.WithRenderer(renderer.NewColorized(colorCfg)),
		tablewriter.WithRendition(rendition),
	)

	// The header labels must stay in step with contextSizes below.
	tbl.Header([]string{"Quant|Ctx", "BPW", "2K", "8K", "16K", "32K", "49K", "64K"})
	contextSizes := []int{2048, 8192, 16384, 32768, 49152, 65536}

	// Prepare data rows
	for _, result := range table.Results {
		row := make([]string, 0, 2+len(contextSizes))
		row = append(row, result.QuantType, fmt.Sprintf("%.2f", result.BPW))

		// Add VRAM estimates for each context size
		for _, ctxSize := range contextSizes {
			vram := result.Contexts[ctxSize]
			cell := getColouredVRAM(vram.VRAM, fmt.Sprintf("%.1f", vram.VRAM), table.FitsVRAM)
			if ctxSize >= 16384 {
				q8Str := getColouredVRAM(vram.VRAMQ8_0, fmt.Sprintf("%.1f", vram.VRAMQ8_0), table.FitsVRAM)
				q4Str := getColouredVRAM(vram.VRAMQ4_0, fmt.Sprintf("%.1f", vram.VRAMQ4_0), table.FitsVRAM)
				cell = fmt.Sprintf("%s(%s,%s)", cell, q8Str, q4Str)
			}
			row = append(row, cell)
		}

		// The function has no error return, so report the failure and carry on
		// with the remaining rows instead of dropping it silently.
		if err := tbl.Append(row); err != nil {
			log.Printf("Failed to append table row for %s: %v", result.QuantType, err)
		}
	}

	// Render the table
	if err := tbl.Render(); err != nil {
		log.Printf("Failed to render VRAM estimation table: %v", err)
	}

	return lipgloss.NewStyle().Foreground(lipgloss.Color("#ffffff")).Render(fmt.Sprintf("📊 VRAM Estimation for Model: %s\n\n%s", table.ModelID, buf.String()))
}