Skip to content

Commit

Permalink
feat: Allow forcing either IPV4 or IPV6 network connections only
Browse files Browse the repository at this point in the history
  • Loading branch information
prantlf committed Jun 5, 2024
1 parent 54dbb29 commit 5c0a9e4
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 2 deletions.
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ ENTRYPOINT ["/ovai"]
ARG DEBUG=ovai,ovai:srv
ENV DEBUG=${DEBUG}
ENV PORT=22434
ENV NETWORK=
ENV OLLAMA_ORIGIN=

# HEALTHCHECK --interval=5m \
# CMD /healthchk http://localhost:22434/api/ping || exit 1
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ Set the environment variable `DEBUG` to one or more strings separated by commas

Set the environment variable `OLLAMA_ORIGIN` to the origin of the `ollama` service to enable forwarding to `ollama`. If the requested model doesn't start with `gemini`, `chat-bison`, `text-bison` or `textembedding-gecko`, the request will be forwarded to the `ollama` service. This can be used for using `ovai` as the single service with the `ollama` interface, which recognises both `Vertex AI` and `ollama` models.

Set the environment variable `NETWORK` to enforce IPV4 or IPV6. The default behaviour is to depend on tHe [Happy Eyeballs] implementation in Go and in the underlying OS. valid values:

| `NETWORK` value | What will be used |
|:----------------|:---------------------------------------------|
| `IPV4` | enforce the network connection via IPV4 only |
| `IPV6` | enforce the network connection via IPV6 only |

### Docker

For example, run a container for testing purposes with verbose logging, deleted on exit, exposing the port 22434:
Expand Down Expand Up @@ -226,6 +233,7 @@ Licensed under the [MIT License].
[GitHub Releases]: https://github.com/prantlf/ovai/releases/
[Go]: https://go.dev
[default model parameters]: ./model-defaults.json
[Happy Eyeballs] :https://en.wikipedia.org/wiki/Happy_Eyeballs
[docker-compose.yml]: ./docker-compose.yml
[REST API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
[embedding models]: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings#model_versions
Expand Down
63 changes: 61 additions & 2 deletions internal/web/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package web

import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net"
"net/http"
"os"

"github.com/prantlf/ovai/internal/log"
)
Expand Down Expand Up @@ -121,8 +124,64 @@ func CreateRawPostRequest(url string, input []byte) (*http.Request, error) {
return req, nil
}

func DispatchRequest(req *http.Request, output interface{}) (int, error) {
// type happyEyeballs int

// const (
// ipDefault happyEyeballs = iota + 1
// ipV4
// ipV6
// )

// var happyEyeballNames = [...]string{"Default", "IPV4", "IPV6"}

// func (h happyEyeballs) String() string {
// return happyEyeballNames[h-1]
// }

// func (h happyEyeballs) EnumIndex() int {
// return int(h)
// }

// func parseHappyEyeballs(input string) (happyEyeballs, error) {
// for index, value := range happyEyeballNames {
// if value == input {
// return happyEyeballs(index), nil
// }
// }
// return 0, fmt.Errorf("invalid enum value: %q", input)
// }

var dialer net.Dialer
var networkVersion = initDialer()

func initDialer() string {
network := os.Getenv("NETWORK")
if len(network) > 0 {
if network == "IPV4" {
return "tcp4"
}
if network == "IPV6" {
return "tcp6"
}
log.Ftl("Invalid value of NETWORK variable: %q", network)
}
return ""
}

func createHttpClient() *http.Client {
client := &http.Client{}
if len(networkVersion) > 0 {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, networkVersion, addr)
}
client.Transport = transport
}
return client
}

func DispatchRequest(req *http.Request, output interface{}) (int, error) {
client := createHttpClient()
res, err := client.Do(req)
if err != nil {
log.Dbg("making request failed: %v", err)
Expand Down Expand Up @@ -170,7 +229,7 @@ func DispatchRequest(req *http.Request, output interface{}) (int, error) {
}

func DispatchRawRequest(req *http.Request) (int, []byte, error) {
client := &http.Client{}
client := createHttpClient()
res, err := client.Do(req)
if err != nil {
log.Dbg("making request failed: %v", err)
Expand Down

0 comments on commit 5c0a9e4

Please sign in to comment.