@@ -14,12 +14,12 @@ import (
1414 "errors"
1515 "fmt"
1616 "io"
17+ "io/fs"
1718 "log"
1819 "math/rand"
1920 "net/http"
2021 "os"
2122 "os/exec"
22- "path"
2323 "path/filepath"
2424 "runtime"
2525 "strconv"
@@ -54,6 +54,7 @@ const (
5454 randomSeedVariableName = "randomSeed"
5555 nowVariableName = "now"
5656 ModeEnvironmentVariableName = "AZURE_TEST_MODE"
57+ recordingAssetConfigName = "assets.json"
5758)
5859
5960// Inspired by https://stackoverflow.com/questions/22892120/how-to-generate-a-random-string-of-a-fixed-length-in-go
@@ -574,7 +575,95 @@ func (r RecordingOptions) baseURL() string {
574575}
575576
576577func getTestId (pathToRecordings string , t * testing.T ) string {
577- return path .Join (pathToRecordings , "recordings" , t .Name ()+ ".json" )
578+ return filepath .Join (pathToRecordings , "recordings" , t .Name ()+ ".json" )
579+ }
580+
581+ func getGitRoot (fromPath string ) (string , error ) {
582+ absPath , err := filepath .Abs (fromPath )
583+ if err != nil {
584+ return "" , err
585+ }
586+ cmd := exec .Command ("git" , "rev-parse" , "--show-toplevel" )
587+ cmd .Dir = absPath
588+
589+ root , err := cmd .CombinedOutput ()
590+ if err != nil {
591+ return "" , fmt .Errorf ("Unable to find git root for path '%s'" , absPath )
592+ }
593+
594+ // Wrap with Abs() to get os-specific path separators to support sub-path matching
595+ return filepath .Abs (strings .TrimSpace (string (root )))
596+ }
597+
598+ // Traverse up from a recording path until an asset config file is found.
599+ // Stop searching when the root of the git repository is reached.
600+ func findAssetsConfigFile (fromPath string , untilPath string ) (string , error ) {
601+ absPath , err := filepath .Abs (fromPath )
602+ if err != nil {
603+ return "" , err
604+ }
605+ assetConfigPath := filepath .Join (absPath , recordingAssetConfigName )
606+
607+ if _ , err := os .Stat (assetConfigPath ); err == nil {
608+ return assetConfigPath , nil
609+ } else if ! errors .Is (err , fs .ErrNotExist ) {
610+ return "" , err
611+ }
612+
613+ if absPath == untilPath {
614+ return "" , nil
615+ }
616+
617+ parentDir := filepath .Dir (absPath )
618+ // This shouldn't be hit due to checks in getGitRoot, but it can't hurt to be defensive
619+ if parentDir == absPath || parentDir == "." {
620+ return "" , nil
621+ }
622+
623+ return findAssetsConfigFile (parentDir , untilPath )
624+ }
625+
626+ // Returns absolute and relative paths to an asset configuration file, or an error.
627+ func getAssetsConfigLocation (pathToRecordings string ) (string , string , error ) {
628+ cwd , err := os .Getwd ()
629+ if err != nil {
630+ return "" , "" , err
631+ }
632+ gitRoot , err := getGitRoot (cwd )
633+ if err != nil {
634+ return "" , "" , err
635+ }
636+ abs , err := findAssetsConfigFile (filepath .Join (gitRoot , pathToRecordings ), gitRoot )
637+ if err != nil {
638+ return "" , "" , err
639+ }
640+
641+ // Pass a path relative to the git root to test proxy so that paths
642+ // can be resolved when the repo root is mounted as a volume in a container
643+ rel := strings .Replace (abs , gitRoot , "" , 1 )
644+ rel = strings .TrimLeft (rel , string (os .PathSeparator ))
645+ return abs , rel , nil
646+ }
647+
648+ func requestStart (url string , testId string , assetConfigLocation string ) (* http.Response , error ) {
649+ req , err := http .NewRequest ("POST" , url , nil )
650+ if err != nil {
651+ return nil , err
652+ }
653+
654+ req .Header .Set ("Content-Type" , "application/json" )
655+ reqBody := map [string ]string {"x-recording-file" : testId }
656+ if assetConfigLocation != "" {
657+ reqBody ["x-recording-assets-file" ] = assetConfigLocation
658+ }
659+ marshalled , err := json .Marshal (reqBody )
660+ if err != nil {
661+ return nil , err
662+ }
663+ req .Body = io .NopCloser (bytes .NewReader (marshalled ))
664+ req .ContentLength = int64 (len (marshalled ))
665+
666+ return client .Do (req )
578667}
579668
580669// Start tells the test proxy to begin accepting requests for a given test
@@ -595,25 +684,27 @@ func Start(t *testing.T, pathToRecordings string, options *RecordingOptions) err
595684
596685 testId := getTestId (pathToRecordings , t )
597686
598- url := fmt .Sprintf ("%s/%s/start" , options .baseURL (), recordMode )
599-
600- req , err := http .NewRequest ("POST" , url , nil )
687+ absAssetLocation , relAssetLocation , err := getAssetsConfigLocation (pathToRecordings )
601688 if err != nil {
602689 return err
603690 }
604691
605- req .Header .Set ("Content-Type" , "application/json" )
606- marshalled , err := json .Marshal (map [string ]string {"x-recording-file" : testId })
607- if err != nil {
608- return err
609- }
610- req .Body = io .NopCloser (bytes .NewReader (marshalled ))
611- req .ContentLength = int64 (len (marshalled ))
692+ url := fmt .Sprintf ("%s/%s/start" , options .baseURL (), recordMode )
612693
613- resp , err := client .Do (req )
614- if err != nil {
694+ var resp * http.Response
695+ if absAssetLocation == "" {
696+ resp , err = requestStart (url , testId , "" )
697+ if err != nil {
698+ return err
699+ }
700+ } else if resp , err = requestStart (url , testId , absAssetLocation ); err != nil {
615701 return err
702+ } else if resp .StatusCode >= 400 {
703+ if resp , err = requestStart (url , testId , relAssetLocation ); err != nil {
704+ return err
705+ }
616706 }
707+
617708 recId := resp .Header .Get (IDHeader )
618709 if recId == "" {
619710 b , err := io .ReadAll (resp .Body )
0 commit comments