#region Copyright notice and license
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc.  All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//     * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#endregion

using System;
using System.Collections.Generic;
using System.IO;

namespace XGame.Framework.Network.Protobuf
{
    /// <summary>
    /// Reads and decodes protocol message fields.
    /// </summary>
    /// <remarks>
    /// <para>
    /// This class is generally used by generated code to read appropriate
    /// primitives from the stream. It effectively encapsulates the lowest
    /// levels of protocol buffer format.
    /// </para>
    /// <para>
    /// Repeated fields and map fields are not handled by this class; use the
    /// dedicated repeated-field and map-field collection types (not visible in
    /// this file) to read such fields.
    /// </para>
    /// </remarks>
    public sealed class CodedInputStream : IDisposable
    {
        /// <summary>
        /// Whether to leave the underlying stream open when disposing of this stream.
        /// This is always true when there's no stream and the instance was created
        /// from a byte array.
        /// </summary>
        private readonly bool leaveOpen;

        /// <summary>
        /// Buffer of data read from the stream or provided at construction time.
        /// </summary>
        private byte[] buffer;

        /// <summary>
        /// The index of the buffer at which we need to refill from the stream (if there is one).
        /// Note that this is an end index into <see cref="buffer"/>, not a byte count.
        /// </summary>
        private int bufferSize;

        /// <summary>
        /// Number of bytes physically present in the buffer but hidden because they lie
        /// beyond <see cref="currentLimit"/>.
        /// </summary>
        private int bufferSizeAfterLimit = 0;

        /// <summary>
        /// The position within the current buffer (i.e. the next byte to read)
        /// </summary>
        private int bufferPos = 0;

        /// <summary>
        /// The stream to read further input from, or null if the byte array buffer was provided
        /// directly on construction, with no further data available.
        /// </summary>
        private readonly Stream input;

        /// <summary>
        /// The last tag we read. 0 indicates we've read to the end of the stream
        /// (or haven't read anything yet).
        /// </summary>
        private uint lastTag = 0;

        /// <summary>
        /// The next tag, used to store the value read by PeekTag.
        /// </summary>
        private uint nextTag = 0;

        /// <summary>
        /// Whether <see cref="nextTag"/> currently holds a peeked, not yet consumed, tag.
        /// </summary>
        private bool hasNextTag = false;

        /// <summary>Default maximum nesting depth for messages and groups.</summary>
        public const int DefaultRecursionLimit = 64;

        /// <summary>Default maximum number of bytes read from an underlying stream.</summary>
        public const int DefaultSizeLimit = 64 << 20; // 64MB

        /// <summary>Size of the internal buffer allocated when reading from a stream.</summary>
        public const int BufferSize = 4096;

        /// <summary>
        /// The total number of bytes read before the current buffer. The
        /// total bytes read up to the current position can be computed as
        /// totalBytesRetired + bufferPos.
        /// </summary>
        private int totalBytesRetired = 0;

        /// <summary>
        /// The absolute position of the end of the current message.
        /// </summary>
        private int currentLimit = int.MaxValue;

        private int recursionDepth = 0;

        private readonly int recursionLimit;
        private readonly int sizeLimit;

        #region Construction
        // Note that the checks are performed such that we don't end up checking obviously-valid things
        // like non-null references for arrays we've just created.

        /// <summary>
        /// Creates a new CodedInputStream reading data from the given byte array.
        /// </summary>
        public CodedInputStream(byte[] buffer)
            : this(null, ProtoPreconditions.CheckNotNull(buffer, nameof(buffer)), 0, buffer.Length, true)
        {
        }

        /// <summary>
        /// Creates a new <see cref="CodedInputStream"/> that reads from the given byte array slice.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="offset"/> or <paramref name="length"/> does not describe a range
        /// inside <paramref name="buffer"/>.
        /// </exception>
        public CodedInputStream(byte[] buffer, int offset, int length)
            : this(null, ProtoPreconditions.CheckNotNull(buffer, nameof(buffer)), offset, offset + length, true)
        {
            if (offset < 0 || offset > buffer.Length)
            {
                throw new ArgumentOutOfRangeException(nameof(offset), "Offset must be within the buffer");
            }
            // Compare against the remaining space rather than computing offset + length,
            // which could overflow and wrongly pass the check.
            if (length < 0 || length > buffer.Length - offset)
            {
                throw new ArgumentOutOfRangeException(nameof(length), "Length must be non-negative and within the buffer");
            }
        }

        /// <summary>
        /// Creates a new <see cref="CodedInputStream"/> reading data from the given stream, which will be disposed
        /// when the returned object is disposed.
        /// </summary>
        /// <param name="input">The stream to read from.</param>
        public CodedInputStream(Stream input)
            : this(input, false)
        {
        }

        /// <summary>
        /// Creates a new <see cref="CodedInputStream"/> reading data from the given stream.
        /// </summary>
        /// <param name="input">The stream to read from.</param>
        /// <param name="leaveOpen"><c>true</c> to leave <paramref name="input"/> open when the returned
        /// <see cref="CodedInputStream"/> is disposed; <c>false</c> to dispose of the given stream when the
        /// returned object is disposed.</param>
        public CodedInputStream(Stream input, bool leaveOpen)
            : this(ProtoPreconditions.CheckNotNull(input, nameof(input)), new byte[BufferSize], 0, 0, leaveOpen)
        {
        }

        /// <summary>
        /// Creates a new CodedInputStream reading data from the given
        /// stream and buffer, using the default limits.
        /// </summary>
        public CodedInputStream(Stream input, byte[] buffer, int bufferPos, int bufferSize, bool leaveOpen)
        {
            this.input = input;
            this.buffer = buffer;
            this.bufferPos = bufferPos;
            this.bufferSize = bufferSize;
            this.sizeLimit = DefaultSizeLimit;
            this.recursionLimit = DefaultRecursionLimit;
            this.leaveOpen = leaveOpen;
        }

        /// <summary>
        /// Creates an instance with the default limits and no data source. No buffer is
        /// assigned here; <see cref="Reset"/> is the member that supplies one.
        /// </summary>
        public CodedInputStream()
        {
            this.sizeLimit = DefaultSizeLimit;
            this.recursionLimit = DefaultRecursionLimit;
        }

        /// <summary>
        /// Points this instance at a new byte array and clears all per-message state
        /// (last/peeked tag, retired byte count, limits and recursion depth) so the
        /// instance can be reused.
        /// </summary>
        /// <param name="buff">The byte array to read from.</param>
        /// <param name="bufferPos">Index of the first byte to read.</param>
        /// <param name="bufferSize">End index (exclusive) of the readable data within
        /// <paramref name="buff"/>; this is stored as the buffer end, not as a length.</param>
        public void Reset(byte[] buff, int bufferPos, int bufferSize)
        {
            this.buffer = buff;
            this.bufferPos = bufferPos;
            this.bufferSize = bufferSize;
            lastTag = 0;
            nextTag = 0;
            hasNextTag = false;
            totalBytesRetired = 0;
            bufferSizeAfterLimit = 0;
            currentLimit = int.MaxValue;
            recursionDepth = 0;
        }

        /// <summary>
        /// Creates a new CodedInputStream reading data from the given
        /// stream and buffer, using the specified limits.
        /// </summary>
        /// <remarks>
        /// This chains to the version with the default limits instead of vice versa to avoid
        /// having to check that the default values are valid every time.
        /// </remarks>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="sizeLimit"/> or <paramref name="recursionLimit"/> is not positive.
        /// </exception>
        public CodedInputStream(Stream input, byte[] buffer, int bufferPos, int bufferSize, int sizeLimit, int recursionLimit, bool leaveOpen)
            : this(input, buffer, bufferPos, bufferSize, leaveOpen)
        {
            if (sizeLimit <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(sizeLimit), "Size limit must be positive");
            }
            if (recursionLimit <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(recursionLimit), "Recursion limit must be positive");
            }
            this.sizeLimit = sizeLimit;
            this.recursionLimit = recursionLimit;
        }
        #endregion

        /// <summary>
        /// Creates a <see cref="CodedInputStream"/> with the specified size and recursion limits, reading
        /// from an input stream.
        /// </summary>
        /// <remarks>
        /// This method exists separately from the constructor to reduce the number of constructor overloads.
        /// It is likely to be used considerably less frequently than the constructors, as the default limits
        /// are suitable for most use cases.
        /// </remarks>
        /// <param name="input">The input stream to read from</param>
        /// <param name="sizeLimit">The total limit of data to read from the stream.</param>
        /// <param name="recursionLimit">The maximum recursion depth to allow while reading.</param>
        /// <returns>A <c>CodedInputStream</c> reading from <paramref name="input"/> with the specified size
        /// and recursion limits.</returns>
        public static CodedInputStream CreateWithLimits(Stream input, int sizeLimit, int recursionLimit)
        {
            // Note: we may want an overload accepting leaveOpen
            return new CodedInputStream(input, new byte[BufferSize], 0, 0, sizeLimit, recursionLimit, false);
        }

        /// <summary>
        /// Returns the current position in the input stream, or the position in the input buffer
        /// </summary>
        public long Position
        {
            get
            {
                if (input != null)
                {
                    // Subtract the bytes that have been buffered but not yet consumed.
                    return input.Position - ((bufferSize + bufferSizeAfterLimit) - bufferPos);
                }
                return bufferPos;
            }
        }

        /// <summary>
        /// Returns the last tag read, or 0 if no tags have been read or we've read beyond
        /// the end of the stream.
        /// </summary>
        public uint LastTag { get { return lastTag; } }

        /// <summary>
        /// Returns the size limit for this stream.
        /// </summary>
        /// <remarks>
        /// This limit is applied when reading from the underlying stream, as a sanity check. It is
        /// not applied when reading from a byte array data source without an underlying stream.
        /// The default value is 64MB.
        /// </remarks>
        /// <value>
        /// The size limit.
        /// </value>
        public int SizeLimit { get { return sizeLimit; } }

        /// <summary>
        /// Returns the recursion limit for this stream. This limit is applied whilst reading messages,
        /// to avoid maliciously-recursive data.
        /// </summary>
        /// <remarks>
        /// The default limit is 64.
        /// </remarks>
        /// <value>
        /// The recursion limit for this stream.
        /// </value>
        public int RecursionLimit { get { return recursionLimit; } }

        /// <summary>
        /// Disposes of this instance, potentially closing any underlying stream.
        /// </summary>
        /// <remarks>
        /// As there is no flushing to perform here, disposing of a <see cref="CodedInputStream"/> which
        /// was constructed with the <c>leaveOpen</c> option parameter set to <c>true</c> (or one which
        /// was constructed to read from a byte array) has no effect.
        /// </remarks>
        public void Dispose()
        {
            // The parameterless constructor leaves leaveOpen == false with no stream attached,
            // so the stream must be null-checked before disposing it.
            if (!leaveOpen && input != null)
            {
                input.Dispose();
            }
        }

        #region Validation
        /// <summary>
        /// Verifies that the last call to ReadTag() returned tag 0 - in other words,
        /// we've reached the end of the stream when we expected to.
        /// </summary>
        /// <exception cref="InvalidProtocolBufferException">The
        /// tag read was not the one specified</exception>
        internal void CheckReadEndOfStreamTag()
        {
            if (lastTag != 0)
            {
                throw InvalidProtocolBufferException.MoreDataAvailable();
            }
        }
        #endregion

        #region Reading of tags etc

        /// <summary>
        /// Peeks at the next field tag. This is like calling <see cref="ReadTag"/>, but the
        /// tag is not consumed. (So a subsequent call to <see cref="ReadTag"/> will return the
        /// same value.)
        /// </summary>
        public uint PeekTag()
        {
            if (hasNextTag)
            {
                return nextTag;
            }

            uint savedLast = lastTag;
            nextTag = ReadTag();
            hasNextTag = true;
            lastTag = savedLast; // Undo the side effect of ReadTag
            return nextTag;
        }

        /// <summary>
        /// Reads a field tag, returning the tag of 0 for "end of stream".
        /// </summary>
        /// <remarks>
        /// If this method returns 0, it doesn't necessarily mean the end of all
        /// the data in this CodedInputStream; it may be the end of the logical stream
        /// for an embedded message, for example.
        /// </remarks>
        /// <returns>The next field tag, or 0 for end of stream. (0 is never a valid tag.)</returns>
        public uint ReadTag()
        {
            if (hasNextTag)
            {
                lastTag = nextTag;
                hasNextTag = false;
                return lastTag;
            }

            // Optimize for the incredibly common case of having at least two bytes left in the buffer,
            // and those two bytes being enough to get the tag. This will be true for fields up to 4095.
            if (bufferPos + 2 <= bufferSize)
            {
                int tmp = buffer[bufferPos++];
                if (tmp < 128)
                {
                    lastTag = (uint)tmp;
                }
                else
                {
                    int result = tmp & 0x7f;
                    if ((tmp = buffer[bufferPos++]) < 128)
                    {
                        result |= tmp << 7;
                        lastTag = (uint)result;
                    }
                    else
                    {
                        // Nope, rewind and go the potentially slow route.
                        bufferPos -= 2;
                        lastTag = ReadRawVarint32();
                    }
                }
            }
            else
            {
                if (IsAtEnd)
                {
                    lastTag = 0;
                    return 0; // This is the only case in which we return 0.
                }

                lastTag = ReadRawVarint32();
            }
            if (WireFormat.GetTagFieldNumber(lastTag) == 0)
            {
                // If we actually read a tag with a field of 0, that's not a valid tag.
                throw InvalidProtocolBufferException.InvalidTag();
            }
            return lastTag;
        }

        /// <summary>
        /// Skips the data for the field with the tag we've just read.
        /// This should be called directly after <see cref="ReadTag"/>, when
        /// the caller wishes to skip an unknown field.
        /// </summary>
        /// <remarks>
        /// This method throws <see cref="InvalidProtocolBufferException"/> if the last-read tag was an end-group tag.
        /// If a caller wishes to skip a group, they should skip the whole group, by calling this method after reading the
        /// start-group tag. This behavior allows callers to call this method on any field they don't understand, correctly
        /// resulting in an error if an end-group tag has not been paired with an earlier start-group tag.
        /// </remarks>
        /// <exception cref="InvalidProtocolBufferException">The last tag was an end-group tag</exception>
        /// <exception cref="InvalidOperationException">The last read operation read to the end of the logical stream</exception>
        public void SkipLastField()
        {
            if (lastTag == 0)
            {
                throw new InvalidOperationException("SkipLastField cannot be called at the end of a stream");
            }
            switch (WireFormat.GetTagWireType(lastTag))
            {
                case WireFormat.WireType.StartGroup:
                    SkipGroup(lastTag);
                    break;
                case WireFormat.WireType.EndGroup:
                    throw new InvalidProtocolBufferException(
                        "SkipLastField called on an end-group tag, indicating that the corresponding start-group was missing");
                case WireFormat.WireType.Fixed32:
                    ReadFixed32();
                    break;
                case WireFormat.WireType.Fixed64:
                    ReadFixed64();
                    break;
                case WireFormat.WireType.LengthDelimited:
                    var length = ReadLength();
                    SkipRawBytes(length);
                    break;
                case WireFormat.WireType.Varint:
                    // ReadRawVarint32 consumes the whole varint, discarding any upper bits.
                    ReadRawVarint32();
                    break;
            }
        }

        /// <summary>
        /// Skip a group.
        /// </summary>
        /// <param name="startGroupTag">The start-group tag that opened the group being skipped.</param>
        /// <exception cref="InvalidProtocolBufferException">
        /// The recursion limit was exceeded, the data ended before the end-group tag, or the
        /// end-group tag's field number does not match the start-group tag's.
        /// </exception>
        public void SkipGroup(uint startGroupTag)
        {
            // Note: Currently we expect this to be the way that groups are read. We could put the recursion
            // depth changes into the ReadTag method instead, potentially...
            recursionDepth++;
            if (recursionDepth >= recursionLimit)
            {
                throw InvalidProtocolBufferException.RecursionLimitExceeded();
            }
            uint tag;
            while (true)
            {
                tag = ReadTag();
                if (tag == 0)
                {
                    throw InvalidProtocolBufferException.TruncatedMessage();
                }
                // Can't call SkipLastField for this case- that would throw.
                if (WireFormat.GetTagWireType(tag) == WireFormat.WireType.EndGroup)
                {
                    break;
                }
                // This recursion will allow us to handle nested groups.
                SkipLastField();
            }
            int startField = WireFormat.GetTagFieldNumber(startGroupTag);
            int endField = WireFormat.GetTagFieldNumber(tag);
            if (startField != endField)
            {
                throw new InvalidProtocolBufferException(
                    $"Mismatched end-group tag. Started with field {startField}; ended with field {endField}");
            }
            recursionDepth--;
        }

        /// <summary>
        /// Reads a double field from the stream.
        /// </summary>
        public double ReadDouble()
        {
            return BitConverter.Int64BitsToDouble((long)ReadRawLittleEndian64());
        }

        /// <summary>
        /// Reads a float field from the stream.
        /// </summary>
        public float ReadFloat()
        {
            if (BitConverter.IsLittleEndian && 4 <= bufferSize - bufferPos)
            {
                // Fast path: decode straight out of the buffer without copying.
                float ret = BitConverter.ToSingle(buffer, bufferPos);
                bufferPos += 4;
                return ret;
            }
            else
            {
                byte[] rawBytes = ReadRawBytes(4);
                if (!BitConverter.IsLittleEndian)
                {
                    ByteArray.Reverse(rawBytes);
                }
                return BitConverter.ToSingle(rawBytes, 0);
            }
        }

        /// <summary>
        /// Reads a uint64 field from the stream.
        /// </summary>
        public ulong ReadUInt64()
        {
            return ReadRawVarint64();
        }

        /// <summary>
        /// Reads an int64 field from the stream.
        /// </summary>
        public long ReadInt64()
        {
            return (long)ReadRawVarint64();
        }

        /// <summary>
        /// Reads an int32 field from the stream.
        /// </summary>
        public int ReadInt32()
        {
            return (int)ReadRawVarint32();
        }

        /// <summary>
        /// Reads a fixed64 field from the stream.
        /// </summary>
        public ulong ReadFixed64()
        {
            return ReadRawLittleEndian64();
        }

        /// <summary>
        /// Reads a fixed32 field from the stream.
        /// </summary>
        public uint ReadFixed32()
        {
            return ReadRawLittleEndian32();
        }

        /// <summary>
        /// Reads a bool field from the stream.
        /// </summary>
        public bool ReadBool()
        {
            return ReadRawVarint32() != 0;
        }

        /// <summary>
        /// Reads a string field from the stream.
        /// </summary>
        public string ReadString()
        {
            int length = ReadLength();
            return ReadString(length);
        }

        /// <summary>
        /// Reads <paramref name="length"/> bytes from the stream and decodes them with
        /// <c>CodedOutputStream.Utf8Encoding</c>.
        /// </summary>
        /// <param name="length">The number of encoded bytes to consume.</param>
        /// <returns>The decoded string; the empty string when <paramref name="length"/> is 0.</returns>
        public string ReadString(int length)
        {
            // No need to read any data for an empty string.
            if (length == 0)
            {
                return "";
            }
            // A negative length must not take the fast path: it would satisfy the size
            // comparison and fail inside the encoder. Letting it fall through means
            // ReadRawBytes reports it as a negative-size protocol error, as ReadBytes does.
            if (length > 0 && length <= bufferSize - bufferPos)
            {
                // Fast path:  We already have the bytes in a contiguous buffer, so
                //   just copy directly from it.
                string result = CodedOutputStream.Utf8Encoding.GetString(buffer, bufferPos, length);
                bufferPos += length;
                return result;
            }
            // Slow path: Build a byte array first then copy it.
            return CodedOutputStream.Utf8Encoding.GetString(ReadRawBytes(length), 0, length);
        }

        /// <summary>
        /// Reads an embedded message field value from the stream.
        /// </summary>
        /// <param name="builder">The parser whose <c>MergeFrom</c> is invoked on this stream.</param>
        /// <exception cref="InvalidProtocolBufferException">
        /// The recursion limit was reached, or the embedded message did not end exactly at its
        /// declared length.
        /// </exception>
        public void ReadMessage(IMsgParser builder)
        {
            int length = ReadLength();
            if (recursionDepth >= recursionLimit)
            {
                throw InvalidProtocolBufferException.RecursionLimitExceeded();
            }
            int oldLimit = PushLimit(length);
            ++recursionDepth;
            builder.MergeFrom(this);
            CheckReadEndOfStreamTag();
            // Check that we've read exactly as much data as expected.
            if (!ReachedLimit)
            {
                throw InvalidProtocolBufferException.TruncatedMessage();
            }
            --recursionDepth;
            PopLimit(oldLimit);
        }

        /// <summary>
        /// Reads a bytes field value from the stream.
        /// </summary>
        public ByteString ReadBytes()
        {
            int length = ReadLength();
            if (length <= bufferSize - bufferPos && length > 0)
            {
                // Fast path:  We already have the bytes in a contiguous buffer, so
                //   just copy directly from it.
                ByteString result = ByteString.CopyFrom(buffer, bufferPos, length);
                bufferPos += length;
                return result;
            }
            else
            {
                // Slow path:  Build a byte array and attach it to a new ByteString.
                return ByteString.AttachBytes(ReadRawBytes(length));
            }
        }

        /// <summary>
        /// Reads a uint32 field value from the stream.
        /// </summary>
        public uint ReadUInt32()
        {
            return ReadRawVarint32();
        }

        /// <summary>
        /// Reads an enum field value from the stream.
        /// </summary>
        public int ReadEnum()
        {
            // Currently just a pass-through, but it's nice to separate it logically from ReadInt32.
            return (int)ReadRawVarint32();
        }

        /// <summary>
        /// Reads an sfixed32 field value from the stream.
        /// </summary>
        public int ReadSFixed32()
        {
            return (int)ReadRawLittleEndian32();
        }

        /// <summary>
        /// Reads an sfixed64 field value from the stream.
        /// </summary>
        public long ReadSFixed64()
        {
            return (long)ReadRawLittleEndian64();
        }

        /// <summary>
        /// Reads an sint32 field value from the stream.
        /// </summary>
        public int ReadSInt32()
        {
            return DecodeZigZag32(ReadRawVarint32());
        }

        /// <summary>
        /// Reads an sint64 field value from the stream.
        /// </summary>
        public long ReadSInt64()
        {
            return DecodeZigZag64(ReadRawVarint64());
        }

        /// <summary>
        /// Reads a length for length-delimited data.
        /// </summary>
        /// <remarks>
        /// This is internally just reading a varint, but this method exists
        /// to make the calling code clearer.
        /// </remarks>
        public int ReadLength()
        {
            return (int)ReadRawVarint32();
        }

        /// <summary>
        /// Peeks at the next tag in the stream. If it matches <paramref name="tag"/>,
        /// the tag is consumed and the method returns <c>true</c>; otherwise, the
        /// stream is left in the original position and the method returns <c>false</c>.
        /// </summary>
        public bool MaybeConsumeTag(uint tag)
        {
            if (PeekTag() == tag)
            {
                hasNextTag = false;
                return true;
            }
            return false;
        }

        #endregion

        #region Underlying reading primitives

        /// <summary>
        /// Same code as ReadRawVarint32, but read each byte individually, checking for
        /// buffer overflow.
        /// </summary>
        private uint SlowReadRawVarint32()
        {
            int tmp = ReadRawByte();
            if (tmp < 128)
            {
                return (uint)tmp;
            }
            int result = tmp & 0x7f;
            if ((tmp = ReadRawByte()) < 128)
            {
                result |= tmp << 7;
            }
            else
            {
                result |= (tmp & 0x7f) << 7;
                if ((tmp = ReadRawByte()) < 128)
                {
                    result |= tmp << 14;
                }
                else
                {
                    result |= (tmp & 0x7f) << 14;
                    if ((tmp = ReadRawByte()) < 128)
                    {
                        result |= tmp << 21;
                    }
                    else
                    {
                        result |= (tmp & 0x7f) << 21;
                        result |= (tmp = ReadRawByte()) << 28;
                        if (tmp >= 128)
                        {
                            // Discard upper 32 bits.
                            for (int i = 0; i < 5; i++)
                            {
                                if (ReadRawByte() < 128)
                                {
                                    return (uint)result;
                                }
                            }
                            throw InvalidProtocolBufferException.MalformedVarint();
                        }
                    }
                }
            }
            return (uint)result;
        }

        /// <summary>
        /// Reads a raw Varint from the stream.  If larger than 32 bits, discard the upper bits.
        /// This method is optimised for the case where we've got lots of data in the buffer.
        /// That means we can check the size just once, then just read directly from the buffer
        /// without constant rechecking of the buffer length.
        /// </summary>
        public uint ReadRawVarint32()
        {
            if (bufferPos + 5 > bufferSize)
            {
                return SlowReadRawVarint32();
            }

            int tmp = buffer[bufferPos++];
            if (tmp < 128)
            {
                return (uint)tmp;
            }
            int result = tmp & 0x7f;
            if ((tmp = buffer[bufferPos++]) < 128)
            {
                result |= tmp << 7;
            }
            else
            {
                result |= (tmp & 0x7f) << 7;
                if ((tmp = buffer[bufferPos++]) < 128)
                {
                    result |= tmp << 14;
                }
                else
                {
                    result |= (tmp & 0x7f) << 14;
                    if ((tmp = buffer[bufferPos++]) < 128)
                    {
                        result |= tmp << 21;
                    }
                    else
                    {
                        result |= (tmp & 0x7f) << 21;
                        result |= (tmp = buffer[bufferPos++]) << 28;
                        if (tmp >= 128)
                        {
                            // Discard upper 32 bits.
                            // Note that this has to use ReadRawByte() as we only ensure we've
                            // got at least 5 bytes at the start of the method. This lets us
                            // use the fast path in more cases, and we rarely hit this section of code.
                            for (int i = 0; i < 5; i++)
                            {
                                if (ReadRawByte() < 128)
                                {
                                    return (uint)result;
                                }
                            }
                            throw InvalidProtocolBufferException.MalformedVarint();
                        }
                    }
                }
            }
            return (uint)result;
        }

        /// <summary>
        /// Reads a varint from the input one byte at a time, so that it does not
        /// read any bytes after the end of the varint. If you simply wrapped the
        /// stream in a CodedInputStream and used ReadRawVarint32(Stream)
        /// then you would probably end up reading past the end of the varint since
        /// CodedInputStream buffers its input.
        /// </summary>
        /// <param name="input">The stream to read the varint from.</param>
        /// <returns>The low 32 bits of the decoded varint.</returns>
        public static uint ReadRawVarint32(Stream input)
        {
            int result = 0;
            int offset = 0;
            for (; offset < 32; offset += 7)
            {
                int b = input.ReadByte();
                if (b == -1)
                {
                    throw InvalidProtocolBufferException.TruncatedMessage();
                }
                result |= (b & 0x7f) << offset;
                if ((b & 0x80) == 0)
                {
                    return (uint)result;
                }
            }
            // Keep reading up to 64 bits.
            for (; offset < 64; offset += 7)
            {
                int b = input.ReadByte();
                if (b == -1)
                {
                    throw InvalidProtocolBufferException.TruncatedMessage();
                }
                if ((b & 0x80) == 0)
                {
                    return (uint)result;
                }
            }
            throw InvalidProtocolBufferException.MalformedVarint();
        }

        /// <summary>
        /// Reads a raw varint from the stream.
        /// </summary>
        public ulong ReadRawVarint64()
        {
            int shift = 0;
            ulong result = 0;
            while (shift < 64)
            {
                byte b = ReadRawByte();
                result |= (ulong)(b & 0x7F) << shift;
                if ((b & 0x80) == 0)
                {
                    return result;
                }
                shift += 7;
            }
            throw InvalidProtocolBufferException.MalformedVarint();
        }

        /// <summary>
        /// Reads a 32-bit little-endian integer from the stream.
        /// </summary>
        public uint ReadRawLittleEndian32()
        {
            uint b1 = ReadRawByte();
            uint b2 = ReadRawByte();
            uint b3 = ReadRawByte();
            uint b4 = ReadRawByte();
            return b1 | (b2 << 8) | (b3 << 16) | (b4 << 24);
        }

        /// <summary>
        /// Reads a 64-bit little-endian integer from the stream.
        /// </summary>
        public ulong ReadRawLittleEndian64()
        {
            ulong b1 = ReadRawByte();
            ulong b2 = ReadRawByte();
            ulong b3 = ReadRawByte();
            ulong b4 = ReadRawByte();
            ulong b5 = ReadRawByte();
            ulong b6 = ReadRawByte();
            ulong b7 = ReadRawByte();
            ulong b8 = ReadRawByte();
            return b1 | (b2 << 8) | (b3 << 16) | (b4 << 24)
                   | (b5 << 32) | (b6 << 40) | (b7 << 48) | (b8 << 56);
        }

        /// <summary>
        /// Decode a 32-bit value with ZigZag encoding.
        /// </summary>
        /// <remarks>
        /// ZigZag encodes signed integers into values that can be efficiently
        /// encoded with varint.  (Otherwise, negative values must be
        /// sign-extended to 64 bits to be varint encoded, thus always taking
        /// 10 bytes on the wire.)
        /// </remarks>
        public static int DecodeZigZag32(uint n)
        {
            return (int)(n >> 1) ^ -(int)(n & 1);
        }

        /// <summary>
        /// Decode a 64-bit value with ZigZag encoding.
        /// </summary>
        /// <remarks>
        /// ZigZag encodes signed integers into values that can be efficiently
        /// encoded with varint.  (Otherwise, negative values must be
        /// sign-extended to 64 bits to be varint encoded, thus always taking
        /// 10 bytes on the wire.)
        /// </remarks>
        public static long DecodeZigZag64(ulong n)
        {
            return (long)(n >> 1) ^ -(long)(n & 1);
        }
        #endregion

        #region Internal reading and buffer management

        /// <summary>
        /// Sets currentLimit to (current position) + byteLimit. This is called
        /// when descending into a length-delimited embedded message. The previous
        /// limit is returned.
        /// </summary>
        /// <returns>The old limit.</returns>
        public int PushLimit(int byteLimit)
        {
            if (byteLimit < 0)
            {
                throw InvalidProtocolBufferException.NegativeSize();
            }
            // Convert the relative limit into an absolute stream position.
            byteLimit += totalBytesRetired + bufferPos;
            int oldLimit = currentLimit;
            if (byteLimit > oldLimit)
            {
                throw InvalidProtocolBufferException.TruncatedMessage();
            }
            currentLimit = byteLimit;

            RecomputeBufferSizeAfterLimit();

            return oldLimit;
        }

        /// <summary>
        /// Re-splits the physical buffer contents into the readable part
        /// (<see cref="bufferSize"/>) and the part hidden beyond the current limit
        /// (<see cref="bufferSizeAfterLimit"/>).
        /// </summary>
        private void RecomputeBufferSizeAfterLimit()
        {
            // Start from the full physical buffer, then carve off whatever lies past the limit.
            bufferSize += bufferSizeAfterLimit;
            int bufferEnd = totalBytesRetired + bufferSize;
            if (bufferEnd > currentLimit)
            {
                // Limit is in current buffer.
                bufferSizeAfterLimit = bufferEnd - currentLimit;
                bufferSize -= bufferSizeAfterLimit;
            }
            else
            {
                bufferSizeAfterLimit = 0;
            }
        }

        /// <summary>
        /// Discards the current limit, returning the previous limit.
        /// </summary>
        public void PopLimit(int oldLimit)
        {
            currentLimit = oldLimit;
            RecomputeBufferSizeAfterLimit();
        }

        /// <summary>
        /// Returns whether or not all the data before the limit has been read.
        /// </summary>
        /// <returns></returns>
        public bool ReachedLimit
        {
            get
            {
                if (currentLimit == int.MaxValue)
                {
                    return false;
                }
                int currentAbsolutePosition = totalBytesRetired + bufferPos;
                return currentAbsolutePosition >= currentLimit;
            }
        }

        /// <summary>
        /// Returns true if the stream has reached the end of the input. This is the
        /// case if either the end of the underlying input source has been reached or
        /// the stream has reached a limit created using PushLimit.
        /// </summary>
        public bool IsAtEnd
        {
            get { return bufferPos == bufferSize && !RefillBuffer(false); }
        }

        /// <summary>
        /// Called when buffer is empty to read more bytes from the
        /// input.  If <paramref name="mustSucceed"/> is true, RefillBuffer() guarantees that
        /// either there will be at least one byte in the buffer when it returns
        /// or it will throw an exception.  If <paramref name="mustSucceed"/> is false,
        /// RefillBuffer() returns false if no more bytes were available.
        /// </summary>
        /// <param name="mustSucceed">Whether running out of data is an error.</param>
        /// <returns><c>true</c> if at least one byte is now available; otherwise <c>false</c>.</returns>
        private bool RefillBuffer(bool mustSucceed)
        {
            if (bufferPos < bufferSize)
            {
                throw new InvalidOperationException("RefillBuffer() called when buffer wasn't empty.");
            }

            if (totalBytesRetired + bufferSize == currentLimit)
            {
                // Oops, we hit a limit.
                if (mustSucceed)
                {
                    throw InvalidProtocolBufferException.TruncatedMessage();
                }
                else
                {
                    return false;
                }
            }

            totalBytesRetired += bufferSize;

            bufferPos = 0;
            bufferSize = (input == null) ? 0 : input.Read(buffer, 0, buffer.Length);
            if (bufferSize < 0)
            {
                throw new InvalidOperationException("Stream.Read returned a negative count");
            }
            if (bufferSize == 0)
            {
                if (mustSucceed)
                {
                    throw InvalidProtocolBufferException.TruncatedMessage();
                }
                else
                {
                    return false;
                }
            }
            else
            {
                RecomputeBufferSizeAfterLimit();
                int totalBytesRead =
                    totalBytesRetired + bufferSize + bufferSizeAfterLimit;
                // A negative total means the running count overflowed int.
                if (totalBytesRead > sizeLimit || totalBytesRead < 0)
                {
                    throw InvalidProtocolBufferException.SizeLimitExceeded();
                }
                return true;
            }
        }

        /// <summary>
        /// Read one byte from the input.
        /// </summary>
        /// <exception cref="InvalidProtocolBufferException">
        /// the end of the stream or the current limit was reached
        /// </exception>
        public byte ReadRawByte()
        {
            if (bufferPos == bufferSize)
            {
                RefillBuffer(true);
            }
            return buffer[bufferPos++];
        }

        /// <summary>
        /// Reads a fixed size of bytes from the input.
        /// </summary>
        /// <exception cref="InvalidProtocolBufferException">
        /// the end of the stream or the current limit was reached
        /// </exception>
        public byte[] ReadRawBytes(int size)
        {
            if (size < 0)
            {
                throw InvalidProtocolBufferException.NegativeSize();
            }

            if (totalBytesRetired + bufferPos + size > currentLimit)
            {
                // Read to the end of the stream (up to the current limit) anyway.
                SkipRawBytes(currentLimit - totalBytesRetired - bufferPos);
                // Then fail.
                throw InvalidProtocolBufferException.TruncatedMessage();
            }

            if (size <= bufferSize - bufferPos)
            {
                // We have all the bytes we need already.
                byte[] bytes = new byte[size];
                ByteArray.Copy(buffer, bufferPos, bytes, 0, size);
                bufferPos += size;
                return bytes;
            }
            else if (size < buffer.Length)
            {
                // Reading more bytes than are in the buffer, but not an excessive number
                // of bytes.  We can safely allocate the resulting array ahead of time.

                // First copy what we have.
                byte[] bytes = new byte[size];
                int pos = bufferSize - bufferPos;
                ByteArray.Copy(buffer, bufferPos, bytes, 0, pos);
                bufferPos = bufferSize;

                // We want to use RefillBuffer() and then copy from the buffer into our
                // byte array rather than reading directly into our byte array because
                // the input may be unbuffered.
                RefillBuffer(true);

                while (size - pos > bufferSize)
                {
                    Buffer.BlockCopy(buffer, 0, bytes, pos, bufferSize);
                    pos += bufferSize;
                    bufferPos = bufferSize;
                    RefillBuffer(true);
                }

                ByteArray.Copy(buffer, 0, bytes, pos, size - pos);
                bufferPos = size - pos;

                return bytes;
            }
            else
            {
                // The size is very large.  For security reasons, we can't allocate the
                // entire byte array yet.  The size comes directly from the input, so a
                // maliciously-crafted message could provide a bogus very large size in
                // order to trick the app into allocating a lot of memory.  We avoid this
                // by allocating and reading only a small chunk at a time, so that the
                // malicious message must actually *be* extremely large to cause
                // problems.  Meanwhile, we limit the allowed size of a message elsewhere.

                // Remember the buffer markers since we'll have to copy the bytes out of
                // it later.
                int originalBufferPos = bufferPos;
                int originalBufferSize = bufferSize;

                // Mark the current buffer consumed.
                totalBytesRetired += bufferSize;
                bufferPos = 0;
                bufferSize = 0;

                // Read all the rest of the bytes we need.
                int sizeLeft = size - (originalBufferSize - originalBufferPos);
                List<byte[]> chunks = new List<byte[]>();

                while (sizeLeft > 0)
                {
                    byte[] chunk = new byte[Math.Min(sizeLeft, buffer.Length)];
                    int pos = 0;
                    while (pos < chunk.Length)
                    {
                        int n = (input == null) ? -1 : input.Read(chunk, pos, chunk.Length - pos);
                        if (n <= 0)
                        {
                            throw InvalidProtocolBufferException.TruncatedMessage();
                        }
                        totalBytesRetired += n;
                        pos += n;
                    }
                    sizeLeft -= chunk.Length;
                    chunks.Add(chunk);
                }

                // OK, got everything.  Now concatenate it all into one buffer.
                byte[] bytes = new byte[size];

                // Start by copying the leftover bytes from this.buffer.
                int newPos = originalBufferSize - originalBufferPos;
                ByteArray.Copy(buffer, originalBufferPos, bytes, 0, newPos);

                // And now all the chunks.
                foreach (byte[] chunk in chunks)
                {
                    Buffer.BlockCopy(chunk, 0, bytes, newPos, chunk.Length);
                    newPos += chunk.Length;
                }

                // Done.
                return bytes;
            }
        }

        /// <summary>
        /// Reads and discards <paramref name="size"/> bytes.
        /// </summary>
        /// <exception cref="InvalidProtocolBufferException">the end of the stream
        /// or the current limit was reached</exception>
        private void SkipRawBytes(int size)
        {
            if (size < 0)
            {
                throw InvalidProtocolBufferException.NegativeSize();
            }

            if (totalBytesRetired + bufferPos + size > currentLimit)
            {
                // Read to the end of the stream anyway.
                SkipRawBytes(currentLimit - totalBytesRetired - bufferPos);
                // Then fail.
                throw InvalidProtocolBufferException.TruncatedMessage();
            }

            if (size <= bufferSize - bufferPos)
            {
                // We have all the bytes we need already.
                bufferPos += size;
            }
            else
            {
                // Skipping more bytes than are in the buffer.  First skip what we have.
                int pos = bufferSize - bufferPos;

                // ROK 5/7/2013 Issue #54: should retire all bytes in buffer (bufferSize)
                // totalBytesRetired += pos;
                totalBytesRetired += bufferSize;

                bufferPos = 0;
                bufferSize = 0;

                // Then skip directly from the InputStream for the rest.
                if (pos < size)
                {
                    if (input == null)
                    {
                        throw InvalidProtocolBufferException.TruncatedMessage();
                    }
                    SkipImpl(size - pos);
                    totalBytesRetired += size - pos;
                }
            }
        }

        /// <summary>
        /// Abstraction of skipping to cope with streams which can't really skip.
        /// </summary>
        /// <param name="amountToSkip">The number of bytes to advance the underlying stream by.</param>
        private void SkipImpl(int amountToSkip)
        {
            if (input.CanSeek)
            {
                long previousPosition = input.Position;
                input.Position += amountToSkip;
                // If the stream did not end up where we asked, treat the data as truncated.
                if (input.Position != previousPosition + amountToSkip)
                {
                    throw InvalidProtocolBufferException.TruncatedMessage();
                }
            }
            else
            {
                // Non-seekable stream: read and discard in small chunks.
                byte[] skipBuffer = new byte[Math.Min(1024, amountToSkip)];
                while (amountToSkip > 0)
                {
                    int bytesRead = input.Read(skipBuffer, 0, Math.Min(skipBuffer.Length, amountToSkip));
                    if (bytesRead <= 0)
                    {
                        throw InvalidProtocolBufferException.TruncatedMessage();
                    }
                    amountToSkip -= bytesRead;
                }
            }
        }

        #endregion
    }
}