diff options
Diffstat (limited to 'forged/internal/common/bare')
-rw-r--r-- | forged/internal/common/bare/LICENSE | 203 | ||||
-rw-r--r-- | forged/internal/common/bare/doc.go | 8 | ||||
-rw-r--r-- | forged/internal/common/bare/errors.go | 20 | ||||
-rw-r--r-- | forged/internal/common/bare/limit.go | 58 | ||||
-rw-r--r-- | forged/internal/common/bare/marshal.go | 311 | ||||
-rw-r--r-- | forged/internal/common/bare/reader.go | 190 | ||||
-rw-r--r-- | forged/internal/common/bare/unions.go | 81 | ||||
-rw-r--r-- | forged/internal/common/bare/unmarshal.go | 362 | ||||
-rw-r--r-- | forged/internal/common/bare/varint.go | 30 | ||||
-rw-r--r-- | forged/internal/common/bare/writer.go | 121 |
10 files changed, 1384 insertions, 0 deletions
diff --git a/forged/internal/common/bare/LICENSE b/forged/internal/common/bare/LICENSE new file mode 100644 index 0000000..6b0b127 --- /dev/null +++ b/forged/internal/common/bare/LICENSE @@ -0,0 +1,203 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + diff --git a/forged/internal/common/bare/doc.go b/forged/internal/common/bare/doc.go new file mode 100644 index 0000000..2f12f55 --- /dev/null +++ b/forged/internal/common/bare/doc.go @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Runxi Yu <https://runxiyu.org> + +// Package bare provides primitives to encode and decode BARE messages. +// +// There is no guarantee that this is compatible with the upstream +// implementation at https://git.sr.ht/~sircmpwn/go-bare. +package bare diff --git a/forged/internal/common/bare/errors.go b/forged/internal/common/bare/errors.go new file mode 100644 index 0000000..39c951a --- /dev/null +++ b/forged/internal/common/bare/errors.go @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com> + +package bare + +import ( + "errors" + "fmt" + "reflect" +) + +var ErrInvalidStr = errors.New("String contains invalid UTF-8 sequences") + +type UnsupportedTypeError struct { + Type reflect.Type +} + +func (e *UnsupportedTypeError) Error() string { + return fmt.Sprintf("Unsupported type for marshaling: %s\n", e.Type.String()) +} diff --git a/forged/internal/common/bare/limit.go b/forged/internal/common/bare/limit.go new file mode 100644 index 0000000..212bc05 --- /dev/null +++ b/forged/internal/common/bare/limit.go @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com> + +package bare + +import ( + "errors" + "io" +) + +var ( + maxUnmarshalBytes uint64 = 1024 * 1024 * 32 /* 32 MiB */ + maxArrayLength uint64 = 1024 * 4 /* 4096 elements */ + maxMapSize uint64 = 1024 +) + +// MaxUnmarshalBytes sets the maximum size of a message decoded by unmarshal. +// By default, this is set to 32 MiB. +func MaxUnmarshalBytes(bytes uint64) { + maxUnmarshalBytes = bytes +} + +// MaxArrayLength sets maximum number of elements in array. Defaults to 4096 elements +func MaxArrayLength(length uint64) { + maxArrayLength = length +} + +// MaxMapSize sets maximum size of map. Defaults to 1024 key/value pairs +func MaxMapSize(size uint64) { + maxMapSize = size +} + +// Use MaxUnmarshalBytes to prevent this error from occuring on messages which +// are large by design. +var ErrLimitExceeded = errors.New("Maximum message size exceeded") + +// Identical to io.LimitedReader, except it returns our custom error instead of +// EOF if the limit is reached. +type limitedReader struct { + R io.Reader + N uint64 +} + +func (l *limitedReader) Read(p []byte) (n int, err error) { + if l.N <= 0 { + return 0, ErrLimitExceeded + } + if uint64(len(p)) > l.N { + p = p[0:l.N] + } + n, err = l.R.Read(p) + l.N -= uint64(n) + return +} + +func newLimitedReader(r io.Reader) *limitedReader { + return &limitedReader{r, maxUnmarshalBytes} +} diff --git a/forged/internal/common/bare/marshal.go b/forged/internal/common/bare/marshal.go new file mode 100644 index 0000000..1ce942d --- /dev/null +++ b/forged/internal/common/bare/marshal.go @@ -0,0 +1,311 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com> + +package bare + +import ( + "bytes" + "errors" + "fmt" + "reflect" + "sync" +) + +// A type which implements this interface will be responsible for marshaling +// itself when encountered. +type Marshalable interface { + Marshal(w *Writer) error +} + +var encoderBufferPool = sync.Pool{ + New: func() interface{} { + buf := &bytes.Buffer{} + buf.Grow(32) + return buf + }, +} + +// Marshals a value (val, which must be a pointer) into a BARE message. +// +// The encoding of each struct field can be customized by the format string +// stored under the "bare" key in the struct field's tag. +// +// As a special case, if the field tag is "-", the field is always omitted. +func Marshal(val interface{}) ([]byte, error) { + // reuse buffers from previous serializations + b := encoderBufferPool.Get().(*bytes.Buffer) + defer func() { + b.Reset() + encoderBufferPool.Put(b) + }() + + w := NewWriter(b) + err := MarshalWriter(w, val) + + msg := make([]byte, b.Len()) + copy(msg, b.Bytes()) + + return msg, err +} + +// Marshals a value (val, which must be a pointer) into a BARE message and +// writes it to a Writer. See Marshal for details. +func MarshalWriter(w *Writer, val interface{}) error { + t := reflect.TypeOf(val) + v := reflect.ValueOf(val) + if t.Kind() != reflect.Ptr { + return errors.New("Expected val to be pointer type") + } + + return getEncoder(t.Elem())(w, v.Elem()) +} + +type encodeFunc func(w *Writer, v reflect.Value) error + +var encodeFuncCache sync.Map // map[reflect.Type]encodeFunc + +// get decoder from cache +func getEncoder(t reflect.Type) encodeFunc { + if f, ok := encodeFuncCache.Load(t); ok { + return f.(encodeFunc) + } + + f := encoderFunc(t) + encodeFuncCache.Store(t, f) + return f +} + +var marshalableInterface = reflect.TypeOf((*Unmarshalable)(nil)).Elem() + +func encoderFunc(t reflect.Type) encodeFunc { + if reflect.PointerTo(t).Implements(marshalableInterface) { + return func(w *Writer, v reflect.Value) error { + uv := v.Addr().Interface().(Marshalable) + return uv.Marshal(w) + } + } + + if t.Kind() == reflect.Interface && t.Implements(unionInterface) { + return encodeUnion(t) + } + + switch t.Kind() { + case reflect.Ptr: + return encodeOptional(t.Elem()) + case reflect.Struct: + return encodeStruct(t) + case reflect.Array: + return encodeArray(t) + case reflect.Slice: + return encodeSlice(t) + case reflect.Map: + return encodeMap(t) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return encodeUint + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return encodeInt + case reflect.Float32, reflect.Float64: + return encodeFloat + case reflect.Bool: + return encodeBool + case reflect.String: + return encodeString + } + + return func(w *Writer, v reflect.Value) error { + return &UnsupportedTypeError{v.Type()} + } +} + +func encodeOptional(t reflect.Type) encodeFunc { + return func(w *Writer, v reflect.Value) error { + if v.IsNil() { + return w.WriteBool(false) + } + + if err := w.WriteBool(true); err != nil { + return err + } + + return getEncoder(t)(w, v.Elem()) + } +} + +func encodeStruct(t reflect.Type) encodeFunc { + n := t.NumField() + encoders := make([]encodeFunc, n) + for i := 0; i < n; i++ { + field := t.Field(i) + if field.Tag.Get("bare") == "-" { + continue + } + encoders[i] = getEncoder(field.Type) + } + + return func(w *Writer, v reflect.Value) error { + for i := 0; i < n; i++ { + if encoders[i] == nil { + continue + } + err := encoders[i](w, v.Field(i)) + if err != nil { + return err + } + } + return nil + } +} + +func encodeArray(t reflect.Type) encodeFunc { + f := getEncoder(t.Elem()) + len := t.Len() + + return func(w *Writer, v reflect.Value) error { + for i := 0; i < len; i++ { + if err := f(w, v.Index(i)); err != nil { + return err + } + } + return nil + } +} + +func encodeSlice(t reflect.Type) encodeFunc { + elem := t.Elem() + f := getEncoder(elem) + + return func(w *Writer, v reflect.Value) error { + if err := w.WriteUint(uint64(v.Len())); err != nil { + return err + } + + for i := 0; i < v.Len(); i++ { + if err := f(w, v.Index(i)); err != nil { + return err + } + } + return nil + } +} + +func encodeMap(t reflect.Type) encodeFunc { + keyType := t.Key() + keyf := getEncoder(keyType) + + valueType := t.Elem() + valf := getEncoder(valueType) + + return func(w *Writer, v reflect.Value) error { + if err := w.WriteUint(uint64(v.Len())); err != nil { + return err + } + + iter := v.MapRange() + for iter.Next() { + if err := keyf(w, iter.Key()); err != nil { + return err + } + if err := valf(w, iter.Value()); err != nil { + return err + } + } + return nil + } +} + +func encodeUnion(t reflect.Type) encodeFunc { + ut, ok := unionRegistry[t] + if !ok { + return func(w *Writer, v reflect.Value) error { + return fmt.Errorf("Union type %s is not registered", t.Name()) + } + } + + encoders := make(map[uint64]encodeFunc) + for tag, t := range ut.types { + encoders[tag] = getEncoder(t) + } + + return func(w *Writer, v reflect.Value) error { + t := v.Elem().Type() + if t.Kind() == reflect.Ptr { + // If T is a valid union value type, *T is valid too. + t = t.Elem() + v = v.Elem() + } + tag, ok := ut.tags[t] + if !ok { + return fmt.Errorf("Invalid union value: %s", v.Elem().String()) + } + + if err := w.WriteUint(tag); err != nil { + return err + } + + return encoders[tag](w, v.Elem()) + } +} + +func encodeUint(w *Writer, v reflect.Value) error { + switch getIntKind(v.Type()) { + case reflect.Uint: + return w.WriteUint(v.Uint()) + + case reflect.Uint8: + return w.WriteU8(uint8(v.Uint())) + + case reflect.Uint16: + return w.WriteU16(uint16(v.Uint())) + + case reflect.Uint32: + return w.WriteU32(uint32(v.Uint())) + + case reflect.Uint64: + return w.WriteU64(uint64(v.Uint())) + } + + panic("not uint") +} + +func encodeInt(w *Writer, v reflect.Value) error { + switch getIntKind(v.Type()) { + case reflect.Int: + return w.WriteInt(v.Int()) + + case reflect.Int8: + return w.WriteI8(int8(v.Int())) + + case reflect.Int16: + return w.WriteI16(int16(v.Int())) + + case reflect.Int32: + return w.WriteI32(int32(v.Int())) + + case reflect.Int64: + return w.WriteI64(int64(v.Int())) + } + + panic("not int") +} + +func encodeFloat(w *Writer, v reflect.Value) error { + switch v.Type().Kind() { + case reflect.Float32: + return w.WriteF32(float32(v.Float())) + case reflect.Float64: + return w.WriteF64(v.Float()) + } + + panic("not float") +} + +func encodeBool(w *Writer, v reflect.Value) error { + return w.WriteBool(v.Bool()) +} + +func encodeString(w *Writer, v reflect.Value) error { + if v.Kind() != reflect.String { + panic("not string") + } + return w.WriteString(v.String()) +} diff --git a/forged/internal/common/bare/reader.go b/forged/internal/common/bare/reader.go new file mode 100644 index 0000000..028a7aa --- /dev/null +++ b/forged/internal/common/bare/reader.go @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com> + +package bare + +import ( + "encoding/binary" + "fmt" + "io" + "math" + "unicode/utf8" + + "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" +) + +type byteReader interface { + io.Reader + io.ByteReader +} + +// A Reader for BARE primitive types. +type Reader struct { + base byteReader + scratch [8]byte +} + +type simpleByteReader struct { + io.Reader + scratch [1]byte +} + +func (r simpleByteReader) ReadByte() (byte, error) { + // using reference type here saves us allocations + _, err := r.Read(r.scratch[:]) + return r.scratch[0], err +} + +// Returns a new BARE primitive reader wrapping the given io.Reader. +func NewReader(base io.Reader) *Reader { + br, ok := base.(byteReader) + if !ok { + br = simpleByteReader{Reader: base} + } + return &Reader{base: br} +} + +func (r *Reader) ReadUint() (uint64, error) { + x, err := binary.ReadUvarint(r.base) + if err != nil { + return x, err + } + return x, nil +} + +func (r *Reader) ReadU8() (uint8, error) { + return r.base.ReadByte() +} + +func (r *Reader) ReadU16() (uint16, error) { + var i uint16 + if _, err := io.ReadAtLeast(r.base, r.scratch[:2], 2); err != nil { + return i, err + } + return binary.LittleEndian.Uint16(r.scratch[:]), nil +} + +func (r *Reader) ReadU32() (uint32, error) { + var i uint32 + if _, err := io.ReadAtLeast(r.base, r.scratch[:4], 4); err != nil { + return i, err + } + return binary.LittleEndian.Uint32(r.scratch[:]), nil +} + +func (r *Reader) ReadU64() (uint64, error) { + var i uint64 + if _, err := io.ReadAtLeast(r.base, r.scratch[:8], 8); err != nil { + return i, err + } + return binary.LittleEndian.Uint64(r.scratch[:]), nil +} + +func (r *Reader) ReadInt() (int64, error) { + return binary.ReadVarint(r.base) +} + +func (r *Reader) ReadI8() (int8, error) { + b, err := r.base.ReadByte() + return int8(b), err +} + +func (r *Reader) ReadI16() (int16, error) { + var i int16 + if _, err := io.ReadAtLeast(r.base, r.scratch[:2], 2); err != nil { + return i, err + } + return int16(binary.LittleEndian.Uint16(r.scratch[:])), nil +} + +func (r *Reader) ReadI32() (int32, error) { + var i int32 + if _, err := io.ReadAtLeast(r.base, r.scratch[:4], 4); err != nil { + return i, err + } + return int32(binary.LittleEndian.Uint32(r.scratch[:])), nil +} + +func (r *Reader) ReadI64() (int64, error) { + var i int64 + if _, err := io.ReadAtLeast(r.base, r.scratch[:], 8); err != nil { + return i, err + } + return int64(binary.LittleEndian.Uint64(r.scratch[:])), nil +} + +func (r *Reader) ReadF32() (float32, error) { + u, err := r.ReadU32() + f := math.Float32frombits(u) + if math.IsNaN(float64(f)) { + return 0.0, fmt.Errorf("NaN is not permitted in BARE floats") + } + return f, err +} + +func (r *Reader) ReadF64() (float64, error) { + u, err := r.ReadU64() + f := math.Float64frombits(u) + if math.IsNaN(f) { + return 0.0, fmt.Errorf("NaN is not permitted in BARE floats") + } + return f, err +} + +func (r *Reader) ReadBool() (bool, error) { + b, err := r.ReadU8() + if err != nil { + return false, err + } + + if b > 1 { + return false, fmt.Errorf("Invalid bool value: %#x", b) + } + + return b == 1, nil +} + +func (r *Reader) ReadString() (string, error) { + buf, err := r.ReadData() + if err != nil { + return "", err + } + if !utf8.Valid(buf) { + return "", ErrInvalidStr + } + return misc.BytesToString(buf), nil +} + +// Reads a fixed amount of arbitrary data, defined by the length of the slice. +func (r *Reader) ReadDataFixed(dest []byte) error { + var amt int = 0 + for amt < len(dest) { + n, err := r.base.Read(dest[amt:]) + if err != nil { + return err + } + amt += n + } + return nil +} + +// Reads arbitrary data whose length is read from the message. +func (r *Reader) ReadData() ([]byte, error) { + l, err := r.ReadUint() + if err != nil { + return nil, err + } + if l >= maxUnmarshalBytes { + return nil, ErrLimitExceeded + } + buf := make([]byte, l) + var amt uint64 = 0 + for amt < l { + n, err := r.base.Read(buf[amt:]) + if err != nil { + return nil, err + } + amt += uint64(n) + } + return buf, nil +} diff --git a/forged/internal/common/bare/unions.go b/forged/internal/common/bare/unions.go new file mode 100644 index 0000000..1020fa0 --- /dev/null +++ b/forged/internal/common/bare/unions.go @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com> + +package bare + +import ( + "fmt" + "reflect" +) + +// Any type which is a union member must implement this interface. You must +// also call RegisterUnion for go-bare to marshal or unmarshal messages which +// utilize your union type. +type Union interface { + IsUnion() +} + +type UnionTags struct { + iface reflect.Type + tags map[reflect.Type]uint64 + types map[uint64]reflect.Type +} + +var ( + unionInterface = reflect.TypeOf((*Union)(nil)).Elem() + unionRegistry map[reflect.Type]*UnionTags +) + +func init() { + unionRegistry = make(map[reflect.Type]*UnionTags) +} + +// Registers a union type in this context. Pass the union interface and the +// list of types associated with it, sorted ascending by their union tag. +func RegisterUnion(iface interface{}) *UnionTags { + ity := reflect.TypeOf(iface).Elem() + if _, ok := unionRegistry[ity]; ok { + panic(fmt.Errorf("Type %s has already been registered", ity.Name())) + } + + if !ity.Implements(reflect.TypeOf((*Union)(nil)).Elem()) { + panic(fmt.Errorf("Type %s does not implement bare.Union", ity.Name())) + } + + utypes := &UnionTags{ + iface: ity, + tags: make(map[reflect.Type]uint64), + types: make(map[uint64]reflect.Type), + } + unionRegistry[ity] = utypes + return utypes +} + +func (ut *UnionTags) Member(t interface{}, tag uint64) *UnionTags { + ty := reflect.TypeOf(t) + if !ty.AssignableTo(ut.iface) { + panic(fmt.Errorf("Type %s does not implement interface %s", + ty.Name(), ut.iface.Name())) + } + if _, ok := ut.tags[ty]; ok { + panic(fmt.Errorf("Type %s is already registered for union %s", + ty.Name(), ut.iface.Name())) + } + if _, ok := ut.types[tag]; ok { + panic(fmt.Errorf("Tag %d is already registered for union %s", + tag, ut.iface.Name())) + } + ut.tags[ty] = tag + ut.types[tag] = ty + return ut +} + +func (ut *UnionTags) TagFor(v interface{}) (uint64, bool) { + tag, ok := ut.tags[reflect.TypeOf(v)] + return tag, ok +} + +func (ut *UnionTags) TypeFor(tag uint64) (reflect.Type, bool) { + t, ok := ut.types[tag] + return t, ok +} diff --git a/forged/internal/common/bare/unmarshal.go b/forged/internal/common/bare/unmarshal.go new file mode 100644 index 0000000..d55f32c --- /dev/null +++ b/forged/internal/common/bare/unmarshal.go @@ -0,0 +1,362 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com> + +package bare + +import ( + "bytes" + "errors" + "fmt" + "io" + "reflect" + "sync" +) + +// A type which implements this interface will be responsible for unmarshaling +// itself when encountered. +type Unmarshalable interface { + Unmarshal(r *Reader) error +} + +// Unmarshals a BARE message into val, which must be a pointer to a value of +// the message type. +func Unmarshal(data []byte, val interface{}) error { + b := bytes.NewReader(data) + r := NewReader(b) + return UnmarshalBareReader(r, val) +} + +// Unmarshals a BARE message into value (val, which must be a pointer), from a +// reader. See Unmarshal for details. +func UnmarshalReader(r io.Reader, val interface{}) error { + r = newLimitedReader(r) + return UnmarshalBareReader(NewReader(r), val) +} + +type decodeFunc func(r *Reader, v reflect.Value) error + +var decodeFuncCache sync.Map // map[reflect.Type]decodeFunc + +func UnmarshalBareReader(r *Reader, val interface{}) error { + t := reflect.TypeOf(val) + v := reflect.ValueOf(val) + if t.Kind() != reflect.Ptr { + return errors.New("Expected val to be pointer type") + } + + return getDecoder(t.Elem())(r, v.Elem()) +} + +// get decoder from cache +func getDecoder(t reflect.Type) decodeFunc { + if f, ok := decodeFuncCache.Load(t); ok { + return f.(decodeFunc) + } + + f := decoderFunc(t) + decodeFuncCache.Store(t, f) + return f +} + +var unmarshalableInterface = reflect.TypeOf((*Unmarshalable)(nil)).Elem() + +func decoderFunc(t reflect.Type) decodeFunc { + if reflect.PointerTo(t).Implements(unmarshalableInterface) { + return func(r *Reader, v reflect.Value) error { + uv := v.Addr().Interface().(Unmarshalable) + return uv.Unmarshal(r) + } + } + + if t.Kind() == reflect.Interface && t.Implements(unionInterface) { + return decodeUnion(t) + } + + switch t.Kind() { + case reflect.Ptr: + return decodeOptional(t.Elem()) + case reflect.Struct: + return decodeStruct(t) + case reflect.Array: + return decodeArray(t) + case reflect.Slice: + return decodeSlice(t) + case reflect.Map: + return decodeMap(t) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return decodeUint + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return decodeInt + case reflect.Float32, reflect.Float64: + return decodeFloat + case reflect.Bool: + return decodeBool + case reflect.String: + return decodeString + } + + return func(r *Reader, v reflect.Value) error { + return &UnsupportedTypeError{v.Type()} + } +} + +func decodeOptional(t reflect.Type) decodeFunc { + return func(r *Reader, v reflect.Value) error { + s, err := r.ReadU8() + if err != nil { + return err + } + + if s > 1 { + return fmt.Errorf("Invalid optional value: %#x", s) + } + + if s == 0 { + return nil + } + + v.Set(reflect.New(t)) + return getDecoder(t)(r, v.Elem()) + } +} + +func decodeStruct(t reflect.Type) decodeFunc { + n := t.NumField() + decoders := make([]decodeFunc, n) + for i := 0; i < n; i++ { + field := t.Field(i) + if field.Tag.Get("bare") == "-" { + continue + } + decoders[i] = getDecoder(field.Type) + } + + return func(r *Reader, v reflect.Value) error { + for i := 0; i < n; i++ { + if decoders[i] == nil { + continue + } + err := decoders[i](r, v.Field(i)) + if err != nil { + return err + } + } + return nil + } +} + +func decodeArray(t reflect.Type) decodeFunc { + f := getDecoder(t.Elem()) + len := t.Len() + + return func(r *Reader, v reflect.Value) error { + for i := 0; i < len; i++ { + err := f(r, v.Index(i)) + if err != nil { + return err + } + } + return nil + } +} + +func decodeSlice(t reflect.Type) decodeFunc { + elem := t.Elem() + f := getDecoder(elem) + + return func(r *Reader, v reflect.Value) error { + len, err := r.ReadUint() + if err != nil { + return err + } + + if len > maxArrayLength { + return fmt.Errorf("Array length %d exceeds configured limit of %d", len, maxArrayLength) + } + + v.Set(reflect.MakeSlice(t, int(len), int(len))) + + for i := 0; i < int(len); i++ { + if err := f(r, v.Index(i)); err != nil { + return err + } + } + return nil + } +} + +func decodeMap(t reflect.Type) decodeFunc { + keyType := t.Key() + keyf := getDecoder(keyType) + + valueType := t.Elem() + valf := getDecoder(valueType) + + return func(r *Reader, v reflect.Value) error { + size, err := r.ReadUint() + if err != nil { + return err + } + + if size > maxMapSize { + return fmt.Errorf("Map size %d exceeds configured limit of %d", size, maxMapSize) + } + + v.Set(reflect.MakeMapWithSize(t, int(size))) + + key := reflect.New(keyType).Elem() + value := reflect.New(valueType).Elem() + + for i := uint64(0); i < size; i++ { + if err := keyf(r, key); err != nil { + return err + } + + if v.MapIndex(key).Kind() > reflect.Invalid { + return fmt.Errorf("Encountered duplicate map key: %v", key.Interface()) + } + + if err := valf(r, value); err != nil { + return err + } + + v.SetMapIndex(key, value) + } + return nil + } +} + +func decodeUnion(t reflect.Type) decodeFunc { + ut, ok := unionRegistry[t] + if !ok { + return func(r *Reader, v reflect.Value) error { + return fmt.Errorf("Union type %s is not registered", t.Name()) + } + } + + decoders := make(map[uint64]decodeFunc) + for tag, t := range ut.types { + t := t + f := getDecoder(t) + + decoders[tag] = func(r *Reader, v reflect.Value) error { + nv := reflect.New(t) + if err := f(r, nv.Elem()); err != nil { + return err + } + + v.Set(nv) + return nil + } + } + + return func(r *Reader, v reflect.Value) error { + tag, err := r.ReadUint() + if err != nil { + return err + } + + if f, ok := decoders[tag]; ok { + return f(r, v) + } + + return fmt.Errorf("Invalid union tag %d for type %s", tag, t.Name()) + } +} + +func decodeUint(r *Reader, v reflect.Value) error { + var err error + switch getIntKind(v.Type()) { + case reflect.Uint: + var u uint64 + u, err = r.ReadUint() + v.SetUint(u) + + case reflect.Uint8: + var u uint8 + u, err = r.ReadU8() + v.SetUint(uint64(u)) + + case reflect.Uint16: + var u uint16 + u, err = r.ReadU16() + v.SetUint(uint64(u)) + case reflect.Uint32: + var u uint32 + u, err = r.ReadU32() + v.SetUint(uint64(u)) + + case reflect.Uint64: + var u uint64 + u, err = r.ReadU64() + v.SetUint(uint64(u)) + + default: + panic("not an uint") + } + + return err +} + +func decodeInt(r *Reader, v reflect.Value) error { + var err error + switch getIntKind(v.Type()) { + case reflect.Int: + var i int64 + i, err = r.ReadInt() + v.SetInt(i) + + case reflect.Int8: + var i int8 + i, err = r.ReadI8() + v.SetInt(int64(i)) + + case reflect.Int16: + var i int16 + i, err = r.ReadI16() + v.SetInt(int64(i)) + case reflect.Int32: + var i int32 + i, err = r.ReadI32() + v.SetInt(int64(i)) + + case reflect.Int64: + var i int64 + i, err = r.ReadI64() + v.SetInt(int64(i)) + + default: + panic("not an int") + } + + return err +} + +func decodeFloat(r *Reader, v reflect.Value) error { + var err error + switch v.Type().Kind() { + case reflect.Float32: + var f float32 + f, err = r.ReadF32() + v.SetFloat(float64(f)) + case reflect.Float64: + var f float64 + f, err = r.ReadF64() + v.SetFloat(f) + default: + panic("not a float") + } + return err +} + +func decodeBool(r *Reader, v reflect.Value) error { + b, err := r.ReadBool() + v.SetBool(b) + return err +} + +func decodeString(r *Reader, v reflect.Value) error { + s, err := r.ReadString() + v.SetString(s) + return err +} diff --git a/forged/internal/common/bare/varint.go b/forged/internal/common/bare/varint.go new file mode 100644 index 0000000..a185ac8 --- /dev/null +++ b/forged/internal/common/bare/varint.go @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com> + +package bare + +import ( + "reflect" +) + +// Int is a variable-length encoded signed integer. +type Int int64 + +// Uint is a variable-length encoded unsigned integer. +type Uint uint64 + +var ( + intType = reflect.TypeOf(Int(0)) + uintType = reflect.TypeOf(Uint(0)) +) + +func getIntKind(t reflect.Type) reflect.Kind { + switch t { + case intType: + return reflect.Int + case uintType: + return reflect.Uint + default: + return t.Kind() + } +} diff --git a/forged/internal/common/bare/writer.go b/forged/internal/common/bare/writer.go new file mode 100644 index 0000000..80cd7e2 --- /dev/null +++ b/forged/internal/common/bare/writer.go @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright (c) 2025 Drew Devault <https://drewdevault.com> + +package bare + +import ( + "encoding/binary" + "fmt" + "io" + "math" + + "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" +) + +// A Writer for BARE primitive types. +type Writer struct { + base io.Writer + scratch [binary.MaxVarintLen64]byte +} + +// Returns a new BARE primitive writer wrapping the given io.Writer. +func NewWriter(base io.Writer) *Writer { + return &Writer{base: base} +} + +func (w *Writer) WriteUint(i uint64) error { + n := binary.PutUvarint(w.scratch[:], i) + _, err := w.base.Write(w.scratch[:n]) + return err +} + +func (w *Writer) WriteU8(i uint8) error { + return binary.Write(w.base, binary.LittleEndian, i) +} + +func (w *Writer) WriteU16(i uint16) error { + return binary.Write(w.base, binary.LittleEndian, i) +} + +func (w *Writer) WriteU32(i uint32) error { + return binary.Write(w.base, binary.LittleEndian, i) +} + +func (w *Writer) WriteU64(i uint64) error { + return binary.Write(w.base, binary.LittleEndian, i) +} + +func (w *Writer) WriteInt(i int64) error { + var buf [binary.MaxVarintLen64]byte + n := binary.PutVarint(buf[:], i) + _, err := w.base.Write(buf[:n]) + return err +} + +func (w *Writer) WriteI8(i int8) error { + return binary.Write(w.base, binary.LittleEndian, i) +} + +func (w *Writer) WriteI16(i int16) error { + return binary.Write(w.base, binary.LittleEndian, i) +} + +func (w *Writer) WriteI32(i int32) error { + return binary.Write(w.base, binary.LittleEndian, i) +} + +func (w *Writer) WriteI64(i int64) error { + return binary.Write(w.base, binary.LittleEndian, i) +} + +func (w *Writer) WriteF32(f float32) error { + if math.IsNaN(float64(f)) { + return fmt.Errorf("NaN is not permitted in BARE floats") + } + return binary.Write(w.base, binary.LittleEndian, f) +} + +func (w *Writer) WriteF64(f float64) error { + if math.IsNaN(f) { + return fmt.Errorf("NaN is not permitted in BARE floats") + } + return binary.Write(w.base, binary.LittleEndian, f) +} + +func (w *Writer) WriteBool(b bool) error { + return binary.Write(w.base, binary.LittleEndian, b) +} + +func (w *Writer) WriteString(str string) error { + return w.WriteData(misc.StringToBytes(str)) +} + +// Writes a fixed amount of arbitrary data, defined by the length of the slice. +func (w *Writer) WriteDataFixed(data []byte) error { + var amt int = 0 + for amt < len(data) { + n, err := w.base.Write(data[amt:]) + if err != nil { + return err + } + amt += n + } + return nil +} + +// Writes arbitrary data whose length is encoded into the message. +func (w *Writer) WriteData(data []byte) error { + err := w.WriteUint(uint64(len(data))) + if err != nil { + return err + } + var amt int = 0 + for amt < len(data) { + n, err := w.base.Write(data[amt:]) + if err != nil { + return err + } + amt += n + } + return nil +} |